From d4f8dea33fdfc83f760b0809ea1da6dd7bf31f45 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 5 Oct 2021 16:35:41 +0100 Subject: [PATCH 01/41] Refactor things from `test_type_promotion.py` to `dtype_helpers.py` --- array_api_tests/array_helpers.py | 4 +- array_api_tests/dtype_helpers.py | 340 ++++++++++++++++++ array_api_tests/hypothesis_helpers.py | 34 +- .../meta_tests/test_array_helpers.py | 2 +- array_api_tests/test_broadcasting.py | 2 +- array_api_tests/test_elementwise_functions.py | 7 +- array_api_tests/test_signatures.py | 2 +- array_api_tests/test_type_promotion.py | 340 +----------------- 8 files changed, 375 insertions(+), 356 deletions(-) create mode 100644 array_api_tests/dtype_helpers.py diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index b6bc2c95..95eedb16 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -9,11 +9,13 @@ _numeric_dtypes, _boolean_dtypes, _dtypes, asarray) from . import _array_module +from .dtype_helpers import dtype_mapping, promotion_table # These are exported here so that they can be included in the special cases # tests from this file. from ._array_module import logical_not, subtract, floor, ceil, where + __all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less', 'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil', 'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN', @@ -369,8 +371,6 @@ def promote_dtypes(dtype1, dtype2): Special case of result_type() which uses the exact type promotion table from the spec. """ - from .test_type_promotion import dtype_mapping, promotion_table - # Equivalent to this, but some libraries may not work properly with using # dtype objects as dict keys # diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py new file mode 100644 index 00000000..861a46cd --- /dev/null +++ b/array_api_tests/dtype_helpers.py @@ -0,0 +1,340 @@ +from . import _array_module as xp + +__all__ = [ + "dtype_mapping", + "promotion_table", + "dtype_nbits", + "dtype_signed", + "input_types", + "dtypes_to_scalars", + "elementwise_function_input_types", + "elementwise_function_output_types", + "binary_operators", + "unary_operators", + "operators_to_functions", +] + +dtype_mapping = { + 'i1': xp.int8, + 'i2': xp.int16, + 'i4': xp.int32, + 'i8': xp.int64, + 'u1': xp.uint8, + 'u2': xp.uint16, + 'u4': xp.uint32, + 'u8': xp.uint64, + 'f4': xp.float32, + 'f8': xp.float64, + 'b': xp.bool, +} + +reverse_dtype_mapping = {v: k for k, v in dtype_mapping.items()} + +def dtype_nbits(dtype): + if dtype == xp.int8: + return 8 + elif dtype == xp.int16: + return 16 + elif dtype == xp.int32: + return 32 + elif dtype == xp.int64: + return 64 + elif dtype == xp.uint8: + return 8 + elif dtype == xp.uint16: + return 16 + elif dtype == xp.uint32: + return 32 + elif dtype == xp.uint64: + return 64 + elif dtype == xp.float32: + return 32 + elif dtype == xp.float64: + return 64 + else: + raise ValueError(f"dtype_nbits is not defined for {dtype}") + +def dtype_signed(dtype): + if dtype in [xp.int8, xp.int16, xp.int32, xp.int64]: + return True + elif dtype in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]: + return False + raise ValueError("dtype_signed is only defined for integer dtypes") + +signed_integer_promotion_table = { + ('i1', 'i1'): 'i1', + ('i1', 'i2'): 'i2', + ('i1', 'i4'): 'i4', + ('i1', 'i8'): 'i8', + ('i2', 'i1'): 'i2', + ('i2', 'i2'): 'i2', + ('i2', 'i4'): 'i4', + ('i2', 'i8'): 'i8', + ('i4', 'i1'): 'i4', + ('i4', 'i2'): 'i4', + ('i4', 'i4'): 'i4', + ('i4', 'i8'): 'i8', + ('i8', 'i1'): 'i8', + ('i8', 'i2'): 'i8', + ('i8', 'i4'): 'i8', + ('i8', 'i8'): 'i8', +} + +unsigned_integer_promotion_table = { + ('u1', 'u1'): 'u1', + ('u1', 'u2'): 'u2', + ('u1', 'u4'): 'u4', + ('u1', 'u8'): 'u8', + ('u2', 'u1'): 'u2', + ('u2', 'u2'): 'u2', + ('u2', 'u4'): 'u4', + ('u2', 'u8'): 'u8', + ('u4', 'u1'): 'u4', + ('u4', 'u2'): 'u4', + ('u4', 'u4'): 'u4', + ('u4', 'u8'): 'u8', + ('u8', 'u1'): 'u8', + ('u8', 'u2'): 'u8', + ('u8', 'u4'): 'u8', + ('u8', 'u8'): 'u8', +} + +mixed_signed_unsigned_promotion_table = { + ('i1', 'u1'): 'i2', + ('i1', 'u2'): 'i4', + ('i1', 'u4'): 'i8', + ('i2', 'u1'): 'i2', + ('i2', 'u2'): 'i4', + ('i2', 'u4'): 'i8', + ('i4', 'u1'): 'i4', + ('i4', 'u2'): 'i4', + ('i4', 'u4'): 'i8', + ('i8', 'u1'): 'i8', + ('i8', 'u2'): 'i8', + ('i8', 'u4'): 'i8', +} + +flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()} + +float_promotion_table = { + ('f4', 'f4'): 'f4', + ('f4', 'f8'): 'f8', + ('f8', 'f4'): 'f8', + ('f8', 'f8'): 'f8', +} + +boolean_promotion_table = { + ('b', 'b'): 'b', +} + +promotion_table = { + **signed_integer_promotion_table, + **unsigned_integer_promotion_table, + **mixed_signed_unsigned_promotion_table, + **flipped_mixed_signed_unsigned_promotion_table, + **float_promotion_table, + **boolean_promotion_table, +} + +input_types = { + 'any': sorted(set(promotion_table.values())), + 'boolean': sorted(set(boolean_promotion_table.values())), + 'floating': sorted(set(float_promotion_table.values())), + 'integer': sorted(set({**signed_integer_promotion_table, + **unsigned_integer_promotion_table}.values())), + 'integer_or_boolean': sorted(set({**signed_integer_promotion_table, + **unsigned_integer_promotion_table, + **boolean_promotion_table}.values())), + 'numeric': sorted(set({**float_promotion_table, + **signed_integer_promotion_table, + **unsigned_integer_promotion_table}.values())), +} + +dtypes_to_scalars = { + 'b': [bool], + 'i1': [int], + 'i2': [int], + 'i4': [int], + 'i8': [int], + # Note: unsigned int dtypes only correspond to positive integers + 'u1': [int], + 'u2': [int], + 'u4': [int], + 'u8': [int], + 'f4': [int, float], + 'f8': [int, float], +} + +elementwise_function_input_types = { + 'abs': 'numeric', + 'acos': 'floating', + 'acosh': 'floating', + 'add': 'numeric', + 'asin': 'floating', + 'asinh': 'floating', + 'atan': 'floating', + 'atan2': 'floating', + 'atanh': 'floating', + 'bitwise_and': 'integer_or_boolean', + 'bitwise_invert': 'integer_or_boolean', + 'bitwise_left_shift': 'integer', + 'bitwise_or': 'integer_or_boolean', + 'bitwise_right_shift': 'integer', + 'bitwise_xor': 'integer_or_boolean', + 'ceil': 'numeric', + 'cos': 'floating', + 'cosh': 'floating', + 'divide': 'floating', + 'equal': 'any', + 'exp': 'floating', + 'expm1': 'floating', + 'floor': 'numeric', + 'floor_divide': 'numeric', + 'greater': 'numeric', + 'greater_equal': 'numeric', + 'isfinite': 'numeric', + 'isinf': 'numeric', + 'isnan': 'numeric', + 'less': 'numeric', + 'less_equal': 'numeric', + 'log': 'floating', + 'logaddexp': 'floating', + 'log10': 'floating', + 'log1p': 'floating', + 'log2': 'floating', + 'logical_and': 'boolean', + 'logical_not': 'boolean', + 'logical_or': 'boolean', + 'logical_xor': 'boolean', + 'multiply': 'numeric', + 'negative': 'numeric', + 'not_equal': 'any', + 'positive': 'numeric', + 'pow': 'floating', + 'remainder': 'numeric', + 'round': 'numeric', + 'sign': 'numeric', + 'sin': 'floating', + 'sinh': 'floating', + 'sqrt': 'floating', + 'square': 'numeric', + 'subtract': 'numeric', + 'tan': 'floating', + 'tanh': 'floating', + 'trunc': 'numeric', +} + +elementwise_function_output_types = { + 'abs': 'promoted', + 'acos': 'promoted', + 'acosh': 'promoted', + 'add': 'promoted', + 'asin': 'promoted', + 'asinh': 'promoted', + 'atan': 'promoted', + 'atan2': 'promoted', + 'atanh': 'promoted', + 'bitwise_and': 'promoted', + 'bitwise_invert': 'promoted', + 'bitwise_left_shift': 'promoted', + 'bitwise_or': 'promoted', + 'bitwise_right_shift': 'promoted', + 'bitwise_xor': 'promoted', + 'ceil': 'promoted', + 'cos': 'promoted', + 'cosh': 'promoted', + 'divide': 'promoted', + 'equal': 'bool', + 'exp': 'promoted', + 'expm1': 'promoted', + 'floor': 'promoted', + 'floor_divide': 'promoted', + 'greater': 'bool', + 'greater_equal': 'bool', + 'isfinite': 'bool', + 'isinf': 'bool', + 'isnan': 'bool', + 'less': 'bool', + 'less_equal': 'bool', + 'log': 'promoted', + 'logaddexp': 'promoted', + 'log10': 'promoted', + 'log1p': 'promoted', + 'log2': 'promoted', + 'logical_and': 'bool', + 'logical_not': 'bool', + 'logical_or': 'bool', + 'logical_xor': 'bool', + 'multiply': 'promoted', + 'negative': 'promoted', + 'not_equal': 'bool', + 'positive': 'promoted', + 'pow': 'promoted', + 'remainder': 'promoted', + 'round': 'promoted', + 'sign': 'promoted', + 'sin': 'promoted', + 'sinh': 'promoted', + 'sqrt': 'promoted', + 'square': 'promoted', + 'subtract': 'promoted', + 'tan': 'promoted', + 'tanh': 'promoted', + 'trunc': 'promoted', +} + +binary_operators = { + '__add__': '+', + '__and__': '&', + '__eq__': '==', + '__floordiv__': '//', + '__ge__': '>=', + '__gt__': '>', + '__le__': '<=', + '__lshift__': '<<', + '__lt__': '<', + '__matmul__': '@', + '__mod__': '%', + '__mul__': '*', + '__ne__': '!=', + '__or__': '|', + '__pow__': '**', + '__rshift__': '>>', + '__sub__': '-', + '__truediv__': '/', + '__xor__': '^', +} + +unary_operators = { + '__abs__': 'abs()', + '__invert__': '~', + '__neg__': '-', + '__pos__': '+', +} + + +operators_to_functions = { + '__abs__': 'abs', + '__add__': 'add', + '__and__': 'bitwise_and', + '__eq__': 'equal', + '__floordiv__': 'floor_divide', + '__ge__': 'greater_equal', + '__gt__': 'greater', + '__le__': 'less_equal', + '__lshift__': 'bitwise_left_shift', + '__lt__': 'less', + '__matmul__': 'matmul', + '__mod__': 'remainder', + '__mul__': 'multiply', + '__ne__': 'not_equal', + '__or__': 'bitwise_or', + '__pow__': 'pow', + '__rshift__': 'bitwise_right_shift', + '__sub__': 'subtract', + '__truediv__': 'divide', + '__xor__': 'bitwise_xor', + '__invert__': 'bitwise_invert', + '__neg__': 'negative', + '__pos__': 'positive', +} diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 91e9767a..71e2859c 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -14,6 +14,7 @@ boolean_dtype_objects, integer_or_boolean_dtype_objects, dtype_objects) from ._array_module import full, float32, float64, bool as bool_dtype, _UndefinedStub +from .dtype_helpers import dtype_mapping, promotion_table from . import _array_module from . import _array_module as xp @@ -48,14 +49,19 @@ shared_dtypes = shared(dtypes, key="dtype") -# TODO: Importing things from test_type_promotion should be replaced by -# something that won't cause a circular import. Right now we use @st.composite -# only because it returns a lazy-evaluated strategy - in the future this method -# should remove the composite wrapper, just returning sampled_from(dtype_pairs) -# instead of drawing from it. -@composite -def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects): - from .test_type_promotion import dtype_mapping, promotion_table + +sorted_table = sorted(promotion_table) +sorted_table = sorted( + sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij) +) +dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in sorted_table] +if FILTER_UNDEFINED_DTYPES: + dtype_pairs = [(i, j) for i, j in dtype_pairs + if not isinstance(i, _UndefinedStub) + and not isinstance(j, _UndefinedStub)] + + +def mutually_promotable_dtypes(dtype_objects=dtype_objects): # sort for shrinking (sampled_from shrinks to the earlier elements in the # list). Give pairs of the same dtypes first, then smaller dtypes, # preferring float, then int, then unsigned int. Note, this might not @@ -63,17 +69,9 @@ def mutually_promotable_dtypes(draw, dtype_objects=dtype_objects): # that draw from dtypes might not draw the same example from different # pairs (XXX: Can we redesign the strategies so that they can prefer # shrinking dtypes over values?) - sorted_table = sorted(promotion_table) - sorted_table = sorted( - sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij) + return sampled_from( + [(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects] ) - dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in sorted_table] - if FILTER_UNDEFINED_DTYPES: - dtype_pairs = [(i, j) for i, j in dtype_pairs - if not isinstance(i, _UndefinedStub) - and not isinstance(j, _UndefinedStub)] - dtype_pairs = [(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects] - return draw(sampled_from(dtype_pairs)) shared_mutually_promotable_dtype_pairs = shared( mutually_promotable_dtypes(), key="mutually_promotable_dtype_pair" diff --git a/array_api_tests/meta_tests/test_array_helpers.py b/array_api_tests/meta_tests/test_array_helpers.py index 7d1d3b3c..e4a78248 100644 --- a/array_api_tests/meta_tests/test_array_helpers.py +++ b/array_api_tests/meta_tests/test_array_helpers.py @@ -1,6 +1,6 @@ from ..array_helpers import exactly_equal, notequal, int_to_dtype from ..hypothesis_helpers import integer_dtypes -from ..test_type_promotion import dtype_nbits, dtype_signed +from ..dtype_helpers import dtype_nbits, dtype_signed from .._array_module import asarray, nan, equal, all from hypothesis import given, assume diff --git a/array_api_tests/test_broadcasting.py b/array_api_tests/test_broadcasting.py index 0c43f79e..69209075 100644 --- a/array_api_tests/test_broadcasting.py +++ b/array_api_tests/test_broadcasting.py @@ -10,7 +10,7 @@ from .hypothesis_helpers import shapes, FILTER_UNDEFINED_DTYPES from .pytest_helpers import raises, doesnt_raise, nargs -from .test_type_promotion import (elementwise_function_input_types, +from .dtype_helpers import (elementwise_function_input_types, input_types, dtype_mapping) from .function_stubs import elementwise_functions from . import _array_module diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index 01db8882..4bae2b29 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -38,6 +38,7 @@ is_float_dtype, not_equal, float64, asarray, dtype_ranges, full, true, false, assert_same_sign, isnan, less) +from .dtype_helpers import dtype_nbits, dtype_signed # We might as well use this implementation rather than requiring # mod.broadcast_shapes(). See test_equal() and others. from .test_broadcasting import broadcast_shapes @@ -205,7 +206,6 @@ def test_atanh(x): @given(two_integer_or_boolean_dtypes.flatmap(lambda i: two_array_scalars(*i))) def test_bitwise_and(args): - from .test_type_promotion import dtype_nbits, dtype_signed x1, x2 = args sanity_check(x1, x2) a = _array_module.bitwise_and(x1, x2) @@ -228,7 +228,6 @@ def test_bitwise_and(args): @given(two_integer_dtypes.flatmap(lambda i: two_array_scalars(*i))) def test_bitwise_left_shift(args): - from .test_type_promotion import dtype_nbits, dtype_signed x1, x2 = args sanity_check(x1, x2) negative_x2 = isnegative(x2) @@ -252,7 +251,6 @@ def test_bitwise_left_shift(args): @given(integer_or_boolean_scalars) def test_bitwise_invert(x): - from .test_type_promotion import dtype_nbits, dtype_signed a = _array_module.bitwise_invert(x) # Compare against the Python ~ operator. # TODO: Generalize this properly for inputs that are arrays. @@ -270,7 +268,6 @@ def test_bitwise_invert(x): @given(two_integer_or_boolean_dtypes.flatmap(lambda i: two_array_scalars(*i))) def test_bitwise_or(args): - from .test_type_promotion import dtype_nbits, dtype_signed x1, x2 = args sanity_check(x1, x2) a = _array_module.bitwise_or(x1, x2) @@ -292,7 +289,6 @@ def test_bitwise_or(args): @given(two_integer_dtypes.flatmap(lambda i: two_array_scalars(*i))) def test_bitwise_right_shift(args): - from .test_type_promotion import dtype_nbits, dtype_signed x1, x2 = args sanity_check(x1, x2) negative_x2 = isnegative(x2) @@ -311,7 +307,6 @@ def test_bitwise_right_shift(args): @given(two_integer_or_boolean_dtypes.flatmap(lambda i: two_array_scalars(*i))) def test_bitwise_xor(args): - from .test_type_promotion import dtype_nbits, dtype_signed x1, x2 = args sanity_check(x1, x2) a = _array_module.bitwise_xor(x1, x2) diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index f2a28e13..237a7446 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -4,7 +4,7 @@ from ._array_module import mod, mod_name, ones, eye, float64, bool, int64 from .pytest_helpers import raises, doesnt_raise -from .test_type_promotion import elementwise_function_input_types, operators_to_functions +from .dtype_helpers import elementwise_function_input_types, operators_to_functions from . import function_stubs diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 50ba2e77..98e1850f 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -13,335 +13,21 @@ from .array_helpers import assert_exactly_equal, dtype_ranges from .function_stubs import elementwise_functions -from ._array_module import (full, int8, int16, int32, int64, uint8, - uint16, uint32, uint64, float32, float64, bool as - bool_dtype) +from ._array_module import (full, bool as bool_dtype) from . import _array_module +from .dtype_helpers import ( + dtype_mapping, + promotion_table, + input_types, + dtypes_to_scalars, + elementwise_function_input_types, + elementwise_function_output_types, + binary_operators, + unary_operators, + operators_to_functions, +) + -dtype_mapping = { - 'i1': int8, - 'i2': int16, - 'i4': int32, - 'i8': int64, - 'u1': uint8, - 'u2': uint16, - 'u4': uint32, - 'u8': uint64, - 'f4': float32, - 'f8': float64, - 'b': bool_dtype, -} - -reverse_dtype_mapping = {v: k for k, v in dtype_mapping.items()} - -def dtype_nbits(dtype): - if dtype == int8: - return 8 - elif dtype == int16: - return 16 - elif dtype == int32: - return 32 - elif dtype == int64: - return 64 - elif dtype == uint8: - return 8 - elif dtype == uint16: - return 16 - elif dtype == uint32: - return 32 - elif dtype == uint64: - return 64 - elif dtype == float32: - return 32 - elif dtype == float64: - return 64 - else: - raise ValueError(f"dtype_nbits is not defined for {dtype}") - -def dtype_signed(dtype): - if dtype in [int8, int16, int32, int64]: - return True - elif dtype in [uint8, uint16, uint32, uint64]: - return False - raise ValueError("dtype_signed is only defined for integer dtypes") - -signed_integer_promotion_table = { - ('i1', 'i1'): 'i1', - ('i1', 'i2'): 'i2', - ('i1', 'i4'): 'i4', - ('i1', 'i8'): 'i8', - ('i2', 'i1'): 'i2', - ('i2', 'i2'): 'i2', - ('i2', 'i4'): 'i4', - ('i2', 'i8'): 'i8', - ('i4', 'i1'): 'i4', - ('i4', 'i2'): 'i4', - ('i4', 'i4'): 'i4', - ('i4', 'i8'): 'i8', - ('i8', 'i1'): 'i8', - ('i8', 'i2'): 'i8', - ('i8', 'i4'): 'i8', - ('i8', 'i8'): 'i8', -} - -unsigned_integer_promotion_table = { - ('u1', 'u1'): 'u1', - ('u1', 'u2'): 'u2', - ('u1', 'u4'): 'u4', - ('u1', 'u8'): 'u8', - ('u2', 'u1'): 'u2', - ('u2', 'u2'): 'u2', - ('u2', 'u4'): 'u4', - ('u2', 'u8'): 'u8', - ('u4', 'u1'): 'u4', - ('u4', 'u2'): 'u4', - ('u4', 'u4'): 'u4', - ('u4', 'u8'): 'u8', - ('u8', 'u1'): 'u8', - ('u8', 'u2'): 'u8', - ('u8', 'u4'): 'u8', - ('u8', 'u8'): 'u8', -} - -mixed_signed_unsigned_promotion_table = { - ('i1', 'u1'): 'i2', - ('i1', 'u2'): 'i4', - ('i1', 'u4'): 'i8', - ('i2', 'u1'): 'i2', - ('i2', 'u2'): 'i4', - ('i2', 'u4'): 'i8', - ('i4', 'u1'): 'i4', - ('i4', 'u2'): 'i4', - ('i4', 'u4'): 'i8', - ('i8', 'u1'): 'i8', - ('i8', 'u2'): 'i8', - ('i8', 'u4'): 'i8', -} - -flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()} - -float_promotion_table = { - ('f4', 'f4'): 'f4', - ('f4', 'f8'): 'f8', - ('f8', 'f4'): 'f8', - ('f8', 'f8'): 'f8', -} - -boolean_promotion_table = { - ('b', 'b'): 'b', -} - -promotion_table = { - **signed_integer_promotion_table, - **unsigned_integer_promotion_table, - **mixed_signed_unsigned_promotion_table, - **flipped_mixed_signed_unsigned_promotion_table, - **float_promotion_table, - **boolean_promotion_table, -} - -input_types = { - 'any': sorted(set(promotion_table.values())), - 'boolean': sorted(set(boolean_promotion_table.values())), - 'floating': sorted(set(float_promotion_table.values())), - 'integer': sorted(set({**signed_integer_promotion_table, - **unsigned_integer_promotion_table}.values())), - 'integer_or_boolean': sorted(set({**signed_integer_promotion_table, - **unsigned_integer_promotion_table, - **boolean_promotion_table}.values())), - 'numeric': sorted(set({**float_promotion_table, - **signed_integer_promotion_table, - **unsigned_integer_promotion_table}.values())), -} - -binary_operators = { - '__add__': '+', - '__and__': '&', - '__eq__': '==', - '__floordiv__': '//', - '__ge__': '>=', - '__gt__': '>', - '__le__': '<=', - '__lshift__': '<<', - '__lt__': '<', - '__matmul__': '@', - '__mod__': '%', - '__mul__': '*', - '__ne__': '!=', - '__or__': '|', - '__pow__': '**', - '__rshift__': '>>', - '__sub__': '-', - '__truediv__': '/', - '__xor__': '^', -} - -unary_operators = { - '__abs__': 'abs()', - '__invert__': '~', - '__neg__': '-', - '__pos__': '+', -} - - -operators_to_functions = { - '__abs__': 'abs', - '__add__': 'add', - '__and__': 'bitwise_and', - '__eq__': 'equal', - '__floordiv__': 'floor_divide', - '__ge__': 'greater_equal', - '__gt__': 'greater', - '__le__': 'less_equal', - '__lshift__': 'bitwise_left_shift', - '__lt__': 'less', - '__matmul__': 'matmul', - '__mod__': 'remainder', - '__mul__': 'multiply', - '__ne__': 'not_equal', - '__or__': 'bitwise_or', - '__pow__': 'pow', - '__rshift__': 'bitwise_right_shift', - '__sub__': 'subtract', - '__truediv__': 'divide', - '__xor__': 'bitwise_xor', - '__invert__': 'bitwise_invert', - '__neg__': 'negative', - '__pos__': 'positive', -} - -dtypes_to_scalars = { - 'b': [bool], - 'i1': [int], - 'i2': [int], - 'i4': [int], - 'i8': [int], - # Note: unsigned int dtypes only correspond to positive integers - 'u1': [int], - 'u2': [int], - 'u4': [int], - 'u8': [int], - 'f4': [int, float], - 'f8': [int, float], -} - -elementwise_function_input_types = { - 'abs': 'numeric', - 'acos': 'floating', - 'acosh': 'floating', - 'add': 'numeric', - 'asin': 'floating', - 'asinh': 'floating', - 'atan': 'floating', - 'atan2': 'floating', - 'atanh': 'floating', - 'bitwise_and': 'integer_or_boolean', - 'bitwise_invert': 'integer_or_boolean', - 'bitwise_left_shift': 'integer', - 'bitwise_or': 'integer_or_boolean', - 'bitwise_right_shift': 'integer', - 'bitwise_xor': 'integer_or_boolean', - 'ceil': 'numeric', - 'cos': 'floating', - 'cosh': 'floating', - 'divide': 'floating', - 'equal': 'any', - 'exp': 'floating', - 'expm1': 'floating', - 'floor': 'numeric', - 'floor_divide': 'numeric', - 'greater': 'numeric', - 'greater_equal': 'numeric', - 'isfinite': 'numeric', - 'isinf': 'numeric', - 'isnan': 'numeric', - 'less': 'numeric', - 'less_equal': 'numeric', - 'log': 'floating', - 'logaddexp': 'floating', - 'log10': 'floating', - 'log1p': 'floating', - 'log2': 'floating', - 'logical_and': 'boolean', - 'logical_not': 'boolean', - 'logical_or': 'boolean', - 'logical_xor': 'boolean', - 'multiply': 'numeric', - 'negative': 'numeric', - 'not_equal': 'any', - 'positive': 'numeric', - 'pow': 'floating', - 'remainder': 'numeric', - 'round': 'numeric', - 'sign': 'numeric', - 'sin': 'floating', - 'sinh': 'floating', - 'sqrt': 'floating', - 'square': 'numeric', - 'subtract': 'numeric', - 'tan': 'floating', - 'tanh': 'floating', - 'trunc': 'numeric', -} - -elementwise_function_output_types = { - 'abs': 'promoted', - 'acos': 'promoted', - 'acosh': 'promoted', - 'add': 'promoted', - 'asin': 'promoted', - 'asinh': 'promoted', - 'atan': 'promoted', - 'atan2': 'promoted', - 'atanh': 'promoted', - 'bitwise_and': 'promoted', - 'bitwise_invert': 'promoted', - 'bitwise_left_shift': 'promoted', - 'bitwise_or': 'promoted', - 'bitwise_right_shift': 'promoted', - 'bitwise_xor': 'promoted', - 'ceil': 'promoted', - 'cos': 'promoted', - 'cosh': 'promoted', - 'divide': 'promoted', - 'equal': 'bool', - 'exp': 'promoted', - 'expm1': 'promoted', - 'floor': 'promoted', - 'floor_divide': 'promoted', - 'greater': 'bool', - 'greater_equal': 'bool', - 'isfinite': 'bool', - 'isinf': 'bool', - 'isnan': 'bool', - 'less': 'bool', - 'less_equal': 'bool', - 'log': 'promoted', - 'logaddexp': 'promoted', - 'log10': 'promoted', - 'log1p': 'promoted', - 'log2': 'promoted', - 'logical_and': 'bool', - 'logical_not': 'bool', - 'logical_or': 'bool', - 'logical_xor': 'bool', - 'multiply': 'promoted', - 'negative': 'promoted', - 'not_equal': 'bool', - 'positive': 'promoted', - 'pow': 'promoted', - 'remainder': 'promoted', - 'round': 'promoted', - 'sign': 'promoted', - 'sin': 'promoted', - 'sinh': 'promoted', - 'sqrt': 'promoted', - 'square': 'promoted', - 'subtract': 'promoted', - 'tan': 'promoted', - 'tanh': 'promoted', - 'trunc': 'promoted', -} elementwise_function_two_arg_func_names = [func_name for func_name in elementwise_functions.__all__ if From d12808bc58f69128f2fb4065491bb67f3b44ddc8 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 5 Oct 2021 17:02:02 +0100 Subject: [PATCH 02/41] Meta test for promote_dtype --- .../meta_tests/test_array_helpers.py | 46 +++++++++++++------ 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/array_api_tests/meta_tests/test_array_helpers.py b/array_api_tests/meta_tests/test_array_helpers.py index e4a78248..7858183d 100644 --- a/array_api_tests/meta_tests/test_array_helpers.py +++ b/array_api_tests/meta_tests/test_array_helpers.py @@ -1,33 +1,51 @@ -from ..array_helpers import exactly_equal, notequal, int_to_dtype -from ..hypothesis_helpers import integer_dtypes -from ..dtype_helpers import dtype_nbits, dtype_signed -from .._array_module import asarray, nan, equal, all - +import pytest from hypothesis import given, assume from hypothesis.strategies import integers +from ..array_helpers import exactly_equal, notequal, int_to_dtype, promote_dtypes +from ..hypothesis_helpers import integer_dtypes +from ..dtype_helpers import dtype_nbits, dtype_signed +from .. import _array_module as xp + # TODO: These meta-tests currently only work with NumPy def test_exactly_equal(): - a = asarray([0, 0., -0., -0., nan, nan, 1, 1]) - b = asarray([0, -1, -0., 0., nan, 1, 1, 2]) + a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1]) + b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2]) - res = asarray([True, False, True, False, True, False, True, False]) - assert all(equal(exactly_equal(a, b), res)) + res = xp.asarray([True, False, True, False, True, False, True, False]) + assert xp.all(xp.equal(exactly_equal(a, b), res)) def test_notequal(): - a = asarray([0, 0., -0., -0., nan, nan, 1, 1]) - b = asarray([0, -1, -0., 0., nan, 1, 1, 2]) + a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1]) + b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2]) - res = asarray([False, True, False, False, False, True, False, True]) - assert all(equal(notequal(a, b), res)) + res = xp.asarray([False, True, False, False, False, True, False, True]) + assert xp.all(xp.equal(notequal(a, b), res)) @given(integers(), integer_dtypes) def test_int_to_dtype(x, dtype): n = dtype_nbits(dtype) signed = dtype_signed(dtype) try: - d = asarray(x, dtype=dtype) + d = xp.asarray(x, dtype=dtype) except OverflowError: assume(False) assert int_to_dtype(x, n, signed) == d + +@pytest.mark.parametrize( + "dtype1, dtype2, result", + [ + (xp.uint8, xp.uint8, xp.uint8), + (xp.uint8, xp.int8, xp.int16), + (xp.int8, xp.int8, xp.int8), + ] +) +def test_promote_dtypes(dtype1, dtype2, result): + assert promote_dtypes(dtype1, dtype2) == result + + +@pytest.mark.parametrize("dtype1, dtype2", [(xp.uint8, xp.float32)]) +def test_promote_dtypes_incompatible_dtypes_fail(dtype1, dtype2): + with pytest.raises(ValueError): + promote_dtypes(dtype1, dtype2) From 29cec271696f92c0005d7b2dac154a372c2a32ba Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 5 Oct 2021 19:46:12 +0100 Subject: [PATCH 03/41] Use full dtype names instead of shorthands --- array_api_tests/dtype_helpers.py | 142 +++++++++++++++---------------- 1 file changed, 71 insertions(+), 71 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 861a46cd..20933f4e 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -15,17 +15,17 @@ ] dtype_mapping = { - 'i1': xp.int8, - 'i2': xp.int16, - 'i4': xp.int32, - 'i8': xp.int64, - 'u1': xp.uint8, - 'u2': xp.uint16, - 'u4': xp.uint32, - 'u8': xp.uint64, - 'f4': xp.float32, - 'f8': xp.float64, - 'b': xp.bool, + 'int8': xp.int8, + 'int16': xp.int16, + 'int32': xp.int32, + 'int64': xp.int64, + 'uint8': xp.uint8, + 'uint16': xp.uint16, + 'uint32': xp.uint32, + 'uint64': xp.uint64, + 'float32': xp.float32, + 'float64': xp.float64, + 'bool': xp.bool, } reverse_dtype_mapping = {v: k for k, v in dtype_mapping.items()} @@ -62,69 +62,69 @@ def dtype_signed(dtype): raise ValueError("dtype_signed is only defined for integer dtypes") signed_integer_promotion_table = { - ('i1', 'i1'): 'i1', - ('i1', 'i2'): 'i2', - ('i1', 'i4'): 'i4', - ('i1', 'i8'): 'i8', - ('i2', 'i1'): 'i2', - ('i2', 'i2'): 'i2', - ('i2', 'i4'): 'i4', - ('i2', 'i8'): 'i8', - ('i4', 'i1'): 'i4', - ('i4', 'i2'): 'i4', - ('i4', 'i4'): 'i4', - ('i4', 'i8'): 'i8', - ('i8', 'i1'): 'i8', - ('i8', 'i2'): 'i8', - ('i8', 'i4'): 'i8', - ('i8', 'i8'): 'i8', + ('int8', 'int8'): 'int8', + ('int8', 'int16'): 'int16', + ('int8', 'int32'): 'int32', + ('int8', 'int64'): 'int64', + ('int16', 'int8'): 'int16', + ('int16', 'int16'): 'int16', + ('int16', 'int32'): 'int32', + ('int16', 'int64'): 'int64', + ('int32', 'int8'): 'int32', + ('int32', 'int16'): 'int32', + ('int32', 'int32'): 'int32', + ('int32', 'int64'): 'int64', + ('int64', 'int8'): 'int64', + ('int64', 'int16'): 'int64', + ('int64', 'int32'): 'int64', + ('int64', 'int64'): 'int64', } unsigned_integer_promotion_table = { - ('u1', 'u1'): 'u1', - ('u1', 'u2'): 'u2', - ('u1', 'u4'): 'u4', - ('u1', 'u8'): 'u8', - ('u2', 'u1'): 'u2', - ('u2', 'u2'): 'u2', - ('u2', 'u4'): 'u4', - ('u2', 'u8'): 'u8', - ('u4', 'u1'): 'u4', - ('u4', 'u2'): 'u4', - ('u4', 'u4'): 'u4', - ('u4', 'u8'): 'u8', - ('u8', 'u1'): 'u8', - ('u8', 'u2'): 'u8', - ('u8', 'u4'): 'u8', - ('u8', 'u8'): 'u8', + ('uint8', 'uint8'): 'uint8', + ('uint8', 'uint16'): 'uint16', + ('uint8', 'uint32'): 'uint32', + ('uint8', 'uint64'): 'uint64', + ('uint16', 'uint8'): 'uint16', + ('uint16', 'uint16'): 'uint16', + ('uint16', 'uint32'): 'uint32', + ('uint16', 'uint64'): 'uint64', + ('uint32', 'uint8'): 'uint32', + ('uint32', 'uint16'): 'uint32', + ('uint32', 'uint32'): 'uint32', + ('uint32', 'uint64'): 'uint64', + ('uint64', 'uint8'): 'uint64', + ('uint64', 'uint16'): 'uint64', + ('uint64', 'uint32'): 'uint64', + ('uint64', 'uint64'): 'uint64', } mixed_signed_unsigned_promotion_table = { - ('i1', 'u1'): 'i2', - ('i1', 'u2'): 'i4', - ('i1', 'u4'): 'i8', - ('i2', 'u1'): 'i2', - ('i2', 'u2'): 'i4', - ('i2', 'u4'): 'i8', - ('i4', 'u1'): 'i4', - ('i4', 'u2'): 'i4', - ('i4', 'u4'): 'i8', - ('i8', 'u1'): 'i8', - ('i8', 'u2'): 'i8', - ('i8', 'u4'): 'i8', + ('int8', 'uint8'): 'int16', + ('int8', 'uint16'): 'int32', + ('int8', 'uint32'): 'int64', + ('int16', 'uint8'): 'int16', + ('int16', 'uint16'): 'int32', + ('int16', 'uint32'): 'int64', + ('int32', 'uint8'): 'int32', + ('int32', 'uint16'): 'int32', + ('int32', 'uint32'): 'int64', + ('int64', 'uint8'): 'int64', + ('int64', 'uint16'): 'int64', + ('int64', 'uint32'): 'int64', } flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()} float_promotion_table = { - ('f4', 'f4'): 'f4', - ('f4', 'f8'): 'f8', - ('f8', 'f4'): 'f8', - ('f8', 'f8'): 'f8', + ('float32', 'float32'): 'float32', + ('float32', 'float64'): 'float64', + ('float64', 'float32'): 'float64', + ('float64', 'float64'): 'float64', } boolean_promotion_table = { - ('b', 'b'): 'b', + ('bool', 'bool'): 'bool', } promotion_table = { @@ -151,18 +151,18 @@ def dtype_signed(dtype): } dtypes_to_scalars = { - 'b': [bool], - 'i1': [int], - 'i2': [int], - 'i4': [int], - 'i8': [int], + 'bool': [bool], + 'int8': [int], + 'int16': [int], + 'int32': [int], + 'int64': [int], # Note: unsigned int dtypes only correspond to positive integers - 'u1': [int], - 'u2': [int], - 'u4': [int], - 'u8': [int], - 'f4': [int, float], - 'f8': [int, float], + 'uint8': [int], + 'uint16': [int], + 'uint32': [int], + 'uint64': [int], + 'float32': [int, float], + 'float64': [int, float], } elementwise_function_input_types = { From 3f1060f2974205f181536b59bedc4ed7e60714eb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 7 Oct 2021 12:32:41 +0100 Subject: [PATCH 04/41] Use dtype object instead of name for dtype helpers --- array_api_tests/array_helpers.py | 20 ++-- array_api_tests/dtype_helpers.py | 151 ++++++++++++------------- array_api_tests/hypothesis_helpers.py | 7 +- array_api_tests/test_broadcasting.py | 13 +-- array_api_tests/test_type_promotion.py | 69 ++++------- 5 files changed, 117 insertions(+), 143 deletions(-) diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index 95eedb16..1c333a92 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -9,7 +9,7 @@ _numeric_dtypes, _boolean_dtypes, _dtypes, asarray) from . import _array_module -from .dtype_helpers import dtype_mapping, promotion_table +from .dtype_helpers import promotion_table # These are exported here so that they can be included in the special cases # tests from this file. @@ -371,14 +371,10 @@ def promote_dtypes(dtype1, dtype2): Special case of result_type() which uses the exact type promotion table from the spec. """ - # Equivalent to this, but some libraries may not work properly with using - # dtype objects as dict keys - # - # d1, d2 = reverse_dtype_mapping[dtype1], reverse_dtype_mapping[dtype2] - - d1 = [i for i in dtype_mapping if dtype_mapping[i] == dtype1][0] - d2 = [i for i in dtype_mapping if dtype_mapping[i] == dtype2][0] - - if (d1, d2) not in promotion_table: - raise ValueError(f"{d1} and {d2} are not type promotable according to the spec (this may indicate a bug in the test suite).") - return dtype_mapping[promotion_table[d1, d2]] + try: + return promotion_table[(dtype1, dtype2)] + except KeyError as e: + raise ValueError( + f"{dtype1} and {dtype2} are not type promotable according to the spec" + f"(this may indicate a bug in the test suite)." + ) from e diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 20933f4e..b955fa6e 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,7 +1,7 @@ from . import _array_module as xp + __all__ = [ - "dtype_mapping", "promotion_table", "dtype_nbits", "dtype_signed", @@ -14,21 +14,6 @@ "operators_to_functions", ] -dtype_mapping = { - 'int8': xp.int8, - 'int16': xp.int16, - 'int32': xp.int32, - 'int64': xp.int64, - 'uint8': xp.uint8, - 'uint16': xp.uint16, - 'uint32': xp.uint32, - 'uint64': xp.uint64, - 'float32': xp.float32, - 'float64': xp.float64, - 'bool': xp.bool, -} - -reverse_dtype_mapping = {v: k for k, v in dtype_mapping.items()} def dtype_nbits(dtype): if dtype == xp.int8: @@ -54,6 +39,7 @@ def dtype_nbits(dtype): else: raise ValueError(f"dtype_nbits is not defined for {dtype}") + def dtype_signed(dtype): if dtype in [xp.int8, xp.int16, xp.int32, xp.int64]: return True @@ -61,72 +47,79 @@ def dtype_signed(dtype): return False raise ValueError("dtype_signed is only defined for integer dtypes") + signed_integer_promotion_table = { - ('int8', 'int8'): 'int8', - ('int8', 'int16'): 'int16', - ('int8', 'int32'): 'int32', - ('int8', 'int64'): 'int64', - ('int16', 'int8'): 'int16', - ('int16', 'int16'): 'int16', - ('int16', 'int32'): 'int32', - ('int16', 'int64'): 'int64', - ('int32', 'int8'): 'int32', - ('int32', 'int16'): 'int32', - ('int32', 'int32'): 'int32', - ('int32', 'int64'): 'int64', - ('int64', 'int8'): 'int64', - ('int64', 'int16'): 'int64', - ('int64', 'int32'): 'int64', - ('int64', 'int64'): 'int64', + (xp.int8, xp.int8): xp.int8, + (xp.int8, xp.int16): xp.int16, + (xp.int8, xp.int32): xp.int32, + (xp.int8, xp.int64): xp.int64, + (xp.int16, xp.int8): xp.int16, + (xp.int16, xp.int16): xp.int16, + (xp.int16, xp.int32): xp.int32, + (xp.int16, xp.int64): xp.int64, + (xp.int32, xp.int8): xp.int32, + (xp.int32, xp.int16): xp.int32, + (xp.int32, xp.int32): xp.int32, + (xp.int32, xp.int64): xp.int64, + (xp.int64, xp.int8): xp.int64, + (xp.int64, xp.int16): xp.int64, + (xp.int64, xp.int32): xp.int64, + (xp.int64, xp.int64): xp.int64, } + unsigned_integer_promotion_table = { - ('uint8', 'uint8'): 'uint8', - ('uint8', 'uint16'): 'uint16', - ('uint8', 'uint32'): 'uint32', - ('uint8', 'uint64'): 'uint64', - ('uint16', 'uint8'): 'uint16', - ('uint16', 'uint16'): 'uint16', - ('uint16', 'uint32'): 'uint32', - ('uint16', 'uint64'): 'uint64', - ('uint32', 'uint8'): 'uint32', - ('uint32', 'uint16'): 'uint32', - ('uint32', 'uint32'): 'uint32', - ('uint32', 'uint64'): 'uint64', - ('uint64', 'uint8'): 'uint64', - ('uint64', 'uint16'): 'uint64', - ('uint64', 'uint32'): 'uint64', - ('uint64', 'uint64'): 'uint64', + (xp.uint8, xp.uint8): xp.uint8, + (xp.uint8, xp.uint16): xp.uint16, + (xp.uint8, xp.uint32): xp.uint32, + (xp.uint8, xp.uint64): xp.uint64, + (xp.uint16, xp.uint8): xp.uint16, + (xp.uint16, xp.uint16): xp.uint16, + (xp.uint16, xp.uint32): xp.uint32, + (xp.uint16, xp.uint64): xp.uint64, + (xp.uint32, xp.uint8): xp.uint32, + (xp.uint32, xp.uint16): xp.uint32, + (xp.uint32, xp.uint32): xp.uint32, + (xp.uint32, xp.uint64): xp.uint64, + (xp.uint64, xp.uint8): xp.uint64, + (xp.uint64, xp.uint16): xp.uint64, + (xp.uint64, xp.uint32): xp.uint64, + (xp.uint64, xp.uint64): xp.uint64, } + mixed_signed_unsigned_promotion_table = { - ('int8', 'uint8'): 'int16', - ('int8', 'uint16'): 'int32', - ('int8', 'uint32'): 'int64', - ('int16', 'uint8'): 'int16', - ('int16', 'uint16'): 'int32', - ('int16', 'uint32'): 'int64', - ('int32', 'uint8'): 'int32', - ('int32', 'uint16'): 'int32', - ('int32', 'uint32'): 'int64', - ('int64', 'uint8'): 'int64', - ('int64', 'uint16'): 'int64', - ('int64', 'uint32'): 'int64', + (xp.int8, xp.uint8): xp.int16, + (xp.int8, xp.uint16): xp.int32, + (xp.int8, xp.uint32): xp.int64, + (xp.int16, xp.uint8): xp.int16, + (xp.int16, xp.uint16): xp.int32, + (xp.int16, xp.uint32): xp.int64, + (xp.int32, xp.uint8): xp.int32, + (xp.int32, xp.uint16): xp.int32, + (xp.int32, xp.uint32): xp.int64, + (xp.int64, xp.uint8): xp.int64, + (xp.int64, xp.uint16): xp.int64, + (xp.int64, xp.uint32): xp.int64, } + flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()} + float_promotion_table = { - ('float32', 'float32'): 'float32', - ('float32', 'float64'): 'float64', - ('float64', 'float32'): 'float64', - ('float64', 'float64'): 'float64', + (xp.float32, xp.float32): xp.float32, + (xp.float32, xp.float64): xp.float64, + (xp.float64, xp.float32): xp.float64, + (xp.float64, xp.float64): xp.float64, } + boolean_promotion_table = { - ('bool', 'bool'): 'bool', + (xp.bool, xp.bool): xp.bool, } + promotion_table = { **signed_integer_promotion_table, **unsigned_integer_promotion_table, @@ -136,6 +129,7 @@ def dtype_signed(dtype): **boolean_promotion_table, } + input_types = { 'any': sorted(set(promotion_table.values())), 'boolean': sorted(set(boolean_promotion_table.values())), @@ -150,21 +144,23 @@ def dtype_signed(dtype): **unsigned_integer_promotion_table}.values())), } + dtypes_to_scalars = { - 'bool': [bool], - 'int8': [int], - 'int16': [int], - 'int32': [int], - 'int64': [int], + xp.bool: [bool], + xp.int8: [int], + xp.int16: [int], + xp.int32: [int], + xp.int64: [int], # Note: unsigned int dtypes only correspond to positive integers - 'uint8': [int], - 'uint16': [int], - 'uint32': [int], - 'uint64': [int], - 'float32': [int, float], - 'float64': [int, float], + xp.uint8: [int], + xp.uint16: [int], + xp.uint32: [int], + xp.uint64: [int], + xp.float32: [int, float], + xp.float64: [int, float], } + elementwise_function_input_types = { 'abs': 'numeric', 'acos': 'floating', @@ -224,6 +220,7 @@ def dtype_signed(dtype): 'trunc': 'numeric', } + elementwise_function_output_types = { 'abs': 'promoted', 'acos': 'promoted', @@ -283,6 +280,7 @@ def dtype_signed(dtype): 'trunc': 'promoted', } + binary_operators = { '__add__': '+', '__and__': '&', @@ -305,6 +303,7 @@ def dtype_signed(dtype): '__xor__': '^', } + unary_operators = { '__abs__': 'abs()', '__invert__': '~', diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 71e2859c..0be6f8d8 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -14,7 +14,7 @@ boolean_dtype_objects, integer_or_boolean_dtype_objects, dtype_objects) from ._array_module import full, float32, float64, bool as bool_dtype, _UndefinedStub -from .dtype_helpers import dtype_mapping, promotion_table +from .dtype_helpers import promotion_table from . import _array_module from . import _array_module as xp @@ -54,9 +54,8 @@ sorted_table = sorted( sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij) ) -dtype_pairs = [(dtype_mapping[i], dtype_mapping[j]) for i, j in sorted_table] if FILTER_UNDEFINED_DTYPES: - dtype_pairs = [(i, j) for i, j in dtype_pairs + sorted_table = [(i, j) for i, j in sorted_table if not isinstance(i, _UndefinedStub) and not isinstance(j, _UndefinedStub)] @@ -70,7 +69,7 @@ def mutually_promotable_dtypes(dtype_objects=dtype_objects): # pairs (XXX: Can we redesign the strategies so that they can prefer # shrinking dtypes over values?) return sampled_from( - [(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects] + [(i, j) for i, j in sorted_table if i in dtype_objects and j in dtype_objects] ) shared_mutually_promotable_dtype_pairs = shared( diff --git a/array_api_tests/test_broadcasting.py b/array_api_tests/test_broadcasting.py index 69209075..f310e209 100644 --- a/array_api_tests/test_broadcasting.py +++ b/array_api_tests/test_broadcasting.py @@ -10,8 +10,7 @@ from .hypothesis_helpers import shapes, FILTER_UNDEFINED_DTYPES from .pytest_helpers import raises, doesnt_raise, nargs -from .dtype_helpers import (elementwise_function_input_types, - input_types, dtype_mapping) +from .dtype_helpers import elementwise_function_input_types, input_types from .function_stubs import elementwise_functions from . import _array_module from ._array_module import ones, _UndefinedStub @@ -111,14 +110,14 @@ def test_broadcast_shapes_explicit_spec(): @pytest.mark.parametrize('func_name', [i for i in elementwise_functions.__all__ if nargs(i) > 1]) -@given(shape1=shapes, shape2=shapes, dtype=data()) -def test_broadcasting_hypothesis(func_name, shape1, shape2, dtype): +@given(shape1=shapes, shape2=shapes, data=data()) +def test_broadcasting_hypothesis(func_name, shape1, shape2, data): # Internal consistency checks assert nargs(func_name) == 2 - dtype = dtype_mapping[dtype.draw(sampled_from(input_types[elementwise_function_input_types[func_name]]))] - if FILTER_UNDEFINED_DTYPES and isinstance(dtype, _UndefinedStub): - assume(False) + dtype = data.draw(sampled_from(input_types[elementwise_function_input_types[func_name]])) + if FILTER_UNDEFINED_DTYPES: + assume(not isinstance(dtype, _UndefinedStub)) func = getattr(_array_module, func_name) if isinstance(func, _array_module._UndefinedStub): diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 98e1850f..443528c9 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -16,7 +16,6 @@ from ._array_module import (full, bool as bool_dtype) from . import _array_module from .dtype_helpers import ( - dtype_mapping, promotion_table, input_types, dtypes_to_scalars, @@ -28,7 +27,6 @@ ) - elementwise_function_two_arg_func_names = [func_name for func_name in elementwise_functions.__all__ if nargs(func_name) > 1] @@ -46,7 +44,7 @@ for d in dtypes) ] -elementwise_function_two_arg_bool_parametrize_ids = ['-'.join((n, d1, d2)) for n, (d1, d2) +elementwise_function_two_arg_bool_parametrize_ids = [f"{n}-{d1}-{d2}" for n, (d1, d2) in elementwise_function_two_arg_bool_parametrize_inputs] # TODO: These functions should still do type promotion internally, but we do @@ -70,9 +68,7 @@ def test_elementwise_function_two_arg_bool_type_promotion(func_name, assert nargs(func_name) == 2 func = getattr(_array_module, func_name) - type1, type2 = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] + dtype1, dtype2 = dtypes fillvalue1 = fillvalues.draw(scalars(just(dtype1))) if func_name in ['bitwise_left_shift', 'bitwise_right_shift']: @@ -105,7 +101,7 @@ def test_elementwise_function_two_arg_bool_type_promotion(func_name, for d in dtypes[0]) ] -elementwise_function_two_arg_promoted_parametrize_ids = ['-'.join((n, d1, d2)) for n, ((d1, d2), _) +elementwise_function_two_arg_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _) in elementwise_function_two_arg_promoted_parametrize_inputs] # TODO: Extend this to all functions (not just elementwise), and handle @@ -124,10 +120,7 @@ def test_elementwise_function_two_arg_promoted_type_promotion(func_name, assert nargs(func_name) == 2 func = getattr(_array_module, func_name) - (type1, type2), res_type = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] - res_dtype = dtype_mapping[res_type] + (dtype1, dtype2), res_dtype = dtypes fillvalue1 = fillvalues.draw(scalars(just(dtype1))) if func_name in ['bitwise_left_shift', 'bitwise_right_shift']: fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) @@ -159,12 +152,12 @@ def test_elementwise_function_two_arg_promoted_type_promotion(func_name, elementwise_function_one_arg_bool_parametrize_inputs = [(func_name, dtypes) for func_name in elementwise_function_one_arg_func_names_bool for dtypes in input_types[elementwise_function_input_types[func_name]]] -elementwise_function_one_arg_bool_parametrize_ids = ['-'.join((n, d)) for n, d +elementwise_function_one_arg_bool_parametrize_ids = [f"{n}-{d}" for n, d in elementwise_function_one_arg_bool_parametrize_inputs] # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('func_name,dtype_name', +@pytest.mark.parametrize('func_name,dtype', elementwise_function_one_arg_bool_parametrize_inputs, ids=elementwise_function_one_arg_bool_parametrize_ids) # The spec explicitly requires type promotion to work for shape 0 @@ -173,11 +166,10 @@ def test_elementwise_function_two_arg_promoted_type_promotion(func_name, # @example(shape=(0,)) @given(shape=shapes, fillvalues=data()) def test_elementwise_function_one_arg_bool(func_name, shape, - dtype_name, fillvalues): + dtype, fillvalues): assert nargs(func_name) == 1 func = getattr(_array_module, func_name) - dtype = dtype_mapping[dtype_name] fillvalue = fillvalues.draw(scalars(just(dtype))) for i in [func, dtype]: @@ -198,12 +190,12 @@ def test_elementwise_function_one_arg_bool(func_name, shape, elementwise_function_one_arg_promoted_parametrize_inputs = [(func_name, dtypes) for func_name in elementwise_function_one_arg_func_names_promoted for dtypes in input_types[elementwise_function_input_types[func_name]]] -elementwise_function_one_arg_promoted_parametrize_ids = ['-'.join((n, d)) for n, d +elementwise_function_one_arg_promoted_parametrize_ids = [f"{n}-{d}" for n, d in elementwise_function_one_arg_promoted_parametrize_inputs] # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('func_name,dtype_name', +@pytest.mark.parametrize('func_name,dtype', elementwise_function_one_arg_promoted_parametrize_inputs, ids=elementwise_function_one_arg_promoted_parametrize_ids) # The spec explicitly requires type promotion to work for shape 0 @@ -212,11 +204,10 @@ def test_elementwise_function_one_arg_bool(func_name, shape, # @example(shape=(0,)) @given(shape=shapes, fillvalues=data()) def test_elementwise_function_one_arg_type_promotion(func_name, shape, - dtype_name, fillvalues): + dtype, fillvalues): assert nargs(func_name) == 1 func = getattr(_array_module, func_name) - dtype = dtype_mapping[dtype_name] fillvalue = fillvalues.draw(scalars(just(dtype))) for i in [func, dtype]: @@ -234,13 +225,13 @@ def test_elementwise_function_one_arg_type_promotion(func_name, shape, for unary_op_name in unary_operators_promoted for dtypes in input_types[elementwise_function_input_types[operators_to_functions[unary_op_name]]] ] -operator_one_arg_promoted_parametrize_ids = ['-'.join((n, d)) for n, d +operator_one_arg_promoted_parametrize_ids = [f"{n}-{d}" for n, d in operator_one_arg_promoted_parametrize_inputs] # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('unary_op_name,dtype_name', +@pytest.mark.parametrize('unary_op_name,dtype', operator_one_arg_promoted_parametrize_inputs, ids=operator_one_arg_promoted_parametrize_ids) # The spec explicitly requires type promotion to work for shape 0 @@ -248,10 +239,9 @@ def test_elementwise_function_one_arg_type_promotion(func_name, shape, # out for now. # @example(shape=(0,)) @given(shape=shapes, fillvalues=data()) -def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype_name, fillvalues): +def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, fillvalues): unary_op = unary_operators[unary_op_name] - dtype = dtype_mapping[dtype_name] fillvalue = fillvalues.draw(scalars(just(dtype))) if isinstance(dtype, _array_module._UndefinedStub): @@ -278,7 +268,7 @@ def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype_name, fillv for dtypes in promotion_table.keys() if all(d in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] for d in dtypes) ] -operator_two_arg_bool_parametrize_ids = ['-'.join((n, d1, d2)) for n, (d1, d2) +operator_two_arg_bool_parametrize_ids = [f"{n}-{d1}-{d2}" for n, (d1, d2) in operator_two_arg_bool_parametrize_inputs] @pytest.mark.parametrize('binary_op_name,dtypes', @@ -289,9 +279,7 @@ def test_operator_two_arg_bool_promotion(binary_op_name, dtypes, two_shapes, fillvalues): binary_op = binary_operators[binary_op_name] - type1, type2 = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] + dtype1, dtype2 = dtypes fillvalue1 = fillvalues.draw(scalars(just(dtype1))) fillvalue2 = fillvalues.draw(scalars(just(dtype2))) @@ -316,7 +304,7 @@ def test_operator_two_arg_bool_promotion(binary_op_name, dtypes, two_shapes, for dtypes in promotion_table.items() if all(d in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] for d in dtypes[0]) ] -operator_two_arg_promoted_parametrize_ids = ['-'.join((n, d1, d2)) for n, ((d1, d2), _) +operator_two_arg_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _) in operator_two_arg_promoted_parametrize_inputs] @pytest.mark.parametrize('binary_op_name,dtypes', @@ -327,10 +315,7 @@ def test_operator_two_arg_promoted_promotion(binary_op_name, dtypes, two_shapes, fillvalues): binary_op = binary_operators[binary_op_name] - (type1, type2), res_type = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] - res_dtype = dtype_mapping[res_type] + (dtype1, dtype2), res_dtype = dtypes fillvalue1 = fillvalues.draw(scalars(just(dtype1))) if binary_op_name in ['>>', '<<']: fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) @@ -354,7 +339,7 @@ def test_operator_two_arg_promoted_promotion(binary_op_name, dtypes, two_shapes, operator_inplace_two_arg_promoted_parametrize_inputs = [(binary_op, dtypes) for binary_op, dtypes in operator_two_arg_promoted_parametrize_inputs if dtypes[0][0] == dtypes[1]] -operator_inplace_two_arg_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], d1, d2)) for n, ((d1, d2), _) +operator_inplace_two_arg_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], str(d1), str(d2))) for n, ((d1, d2), _) in operator_inplace_two_arg_promoted_parametrize_inputs] @pytest.mark.parametrize('binary_op_name,dtypes', @@ -365,10 +350,7 @@ def test_operator_inplace_two_arg_promoted_promotion(binary_op_name, dtypes, two fillvalues): binary_op = binary_operators[binary_op_name] - (type1, type2), res_type = dtypes - dtype1 = dtype_mapping[type1] - dtype2 = dtype_mapping[type2] - res_dtype = dtype_mapping[res_type] + (dtype1, dtype2), res_dtype = dtypes fillvalue1 = fillvalues.draw(scalars(just(dtype1))) if binary_op_name in ['>>', '<<']: fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) @@ -392,15 +374,15 @@ def test_operator_inplace_two_arg_promoted_promotion(binary_op_name, dtypes, two assert res.dtype == res_dtype, f"{dtype1} {binary_op}= {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" -scalar_promotion_parametrize_inputs = [(binary_op_name, dtype_name, scalar_type) +scalar_promotion_parametrize_inputs = [(binary_op_name, dtype, scalar_type) for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - for dtype_name in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] - for scalar_type in dtypes_to_scalars[dtype_name]] + for dtype in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] + for scalar_type in dtypes_to_scalars[dtype]] -@pytest.mark.parametrize('binary_op_name,dtype_name,scalar_type', +@pytest.mark.parametrize('binary_op_name,dtype,scalar_type', scalar_promotion_parametrize_inputs) @given(shape=shapes, python_scalars=data(), fillvalues=data()) -def test_operator_scalar_promotion(binary_op_name, dtype_name, scalar_type, +def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type, shape, python_scalars, fillvalues): """ See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars @@ -408,9 +390,8 @@ def test_operator_scalar_promotion(binary_op_name, dtype_name, scalar_type, binary_op = binary_operators[binary_op_name] if binary_op == '@': pytest.skip("matmul (@) is not supported for scalars") - dtype = dtype_mapping[dtype_name] - if dtype_name in input_types['integer']: + if dtype in input_types['integer']: s = python_scalars.draw(integers(*dtype_ranges[dtype])) else: s = python_scalars.draw(from_type(scalar_type)) From e4461b33e149b5e270f680e433ebd8643079e79b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 7 Oct 2021 12:47:46 +0100 Subject: [PATCH 05/41] Improve parametrize names for scalar promotion tests --- array_api_tests/test_type_promotion.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 443528c9..37e9d9ba 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -374,10 +374,12 @@ def test_operator_inplace_two_arg_promoted_promotion(binary_op_name, dtypes, two assert res.dtype == res_dtype, f"{dtype1} {binary_op}= {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" -scalar_promotion_parametrize_inputs = [(binary_op_name, dtype, scalar_type) - for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - for dtype in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] - for scalar_type in dtypes_to_scalars[dtype]] +scalar_promotion_parametrize_inputs = [ + pytest.param(binary_op_name, dtype, scalar_type, id=f"{binary_op_name}-{dtype}-{scalar_type.__name__}") + for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) + for dtype in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] + for scalar_type in dtypes_to_scalars[dtype] +] @pytest.mark.parametrize('binary_op_name,dtype,scalar_type', scalar_promotion_parametrize_inputs) From 48fd544e2b7afa49def01e1f5893faec44d5a3e5 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 7 Oct 2021 16:47:13 +0100 Subject: [PATCH 06/41] Convert `dtype_signed` and `dtype_nbits` to dicts --- array_api_tests/dtype_helpers.py | 43 ++++++------------- .../meta_tests/test_array_helpers.py | 4 +- array_api_tests/test_elementwise_functions.py | 14 +++--- 3 files changed, 21 insertions(+), 40 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index b955fa6e..43091a88 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -15,37 +15,18 @@ ] -def dtype_nbits(dtype): - if dtype == xp.int8: - return 8 - elif dtype == xp.int16: - return 16 - elif dtype == xp.int32: - return 32 - elif dtype == xp.int64: - return 64 - elif dtype == xp.uint8: - return 8 - elif dtype == xp.uint16: - return 16 - elif dtype == xp.uint32: - return 32 - elif dtype == xp.uint64: - return 64 - elif dtype == xp.float32: - return 32 - elif dtype == xp.float64: - return 64 - else: - raise ValueError(f"dtype_nbits is not defined for {dtype}") - - -def dtype_signed(dtype): - if dtype in [xp.int8, xp.int16, xp.int32, xp.int64]: - return True - elif dtype in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]: - return False - raise ValueError("dtype_signed is only defined for integer dtypes") +dtype_nbits = { + **{d: 8 for d in [xp.int8, xp.uint8]}, + **{d: 16 for d in [xp.int16, xp.uint16]}, + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, +} + + +dtype_signed = { + **{d: True for d in [xp.int8, xp.int16, xp.int32, xp.int64]}, + **{d: False for d in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]}, +} signed_integer_promotion_table = { diff --git a/array_api_tests/meta_tests/test_array_helpers.py b/array_api_tests/meta_tests/test_array_helpers.py index 7858183d..4f5b016a 100644 --- a/array_api_tests/meta_tests/test_array_helpers.py +++ b/array_api_tests/meta_tests/test_array_helpers.py @@ -25,8 +25,8 @@ def test_notequal(): @given(integers(), integer_dtypes) def test_int_to_dtype(x, dtype): - n = dtype_nbits(dtype) - signed = dtype_signed(dtype) + n = dtype_nbits[dtype] + signed = dtype_signed[dtype] try: d = xp.asarray(x, dtype=dtype) except OverflowError: diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index 4bae2b29..a0e59191 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -223,7 +223,7 @@ def test_bitwise_and(args): x = int(x1) y = int(x2) res = int(a) - ans = int_to_dtype(x & y, dtype_nbits(a.dtype), dtype_signed(a.dtype)) + ans = int_to_dtype(x & y, dtype_nbits[a.dtype], dtype_signed[a.dtype]) assert ans == res @given(two_integer_dtypes.flatmap(lambda i: two_array_scalars(*i))) @@ -240,12 +240,12 @@ def test_bitwise_left_shift(args): raise RuntimeError("Error: test_bitwise_left_shift needs to be updated for nonscalar array inputs") x = int(x1) y = int(x2) - if y >= dtype_nbits(a.dtype): + if y >= dtype_nbits[a.dtype]: # Avoid shifting very large y in Python ints ans = 0 else: ans = x << y - ans = int_to_dtype(ans, dtype_nbits(a.dtype), dtype_signed(a.dtype)) + ans = int_to_dtype(ans, dtype_nbits[a.dtype], dtype_signed[a.dtype]) res = int(a) assert ans == res @@ -263,7 +263,7 @@ def test_bitwise_invert(x): else: x = int(x) res = int(a) - ans = int_to_dtype(~x, dtype_nbits(a.dtype), dtype_signed(a.dtype)) + ans = int_to_dtype(~x, dtype_nbits[a.dtype], dtype_signed[a.dtype]) assert ans == res @given(two_integer_or_boolean_dtypes.flatmap(lambda i: two_array_scalars(*i))) @@ -284,7 +284,7 @@ def test_bitwise_or(args): x = int(x1) y = int(x2) res = int(a) - ans = int_to_dtype(x | y, dtype_nbits(a.dtype), dtype_signed(a.dtype)) + ans = int_to_dtype(x | y, dtype_nbits[a.dtype], dtype_signed[a.dtype]) assert ans == res @given(two_integer_dtypes.flatmap(lambda i: two_array_scalars(*i))) @@ -301,7 +301,7 @@ def test_bitwise_right_shift(args): raise RuntimeError("Error: test_bitwise_right_shift needs to be updated for nonscalar array inputs") x = int(x1) y = int(x2) - ans = int_to_dtype(x >> y, dtype_nbits(a.dtype), dtype_signed(a.dtype)) + ans = int_to_dtype(x >> y, dtype_nbits[a.dtype], dtype_signed[a.dtype]) res = int(a) assert ans == res @@ -323,7 +323,7 @@ def test_bitwise_xor(args): x = int(x1) y = int(x2) res = int(a) - ans = int_to_dtype(x ^ y, dtype_nbits(a.dtype), dtype_signed(a.dtype)) + ans = int_to_dtype(x ^ y, dtype_nbits[a.dtype], dtype_signed[a.dtype]) assert ans == res @given(numeric_scalars) From 8c5fad0c5814acbbebb8ad5b9c9c5c1d27f3217a Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 8 Oct 2021 09:10:03 +0100 Subject: [PATCH 07/41] Use `dh.promotion_table` instead of `ah.promote_dtypes()` --- array_api_tests/array_helpers.py | 16 +--------------- .../meta_tests/test_array_helpers.py | 19 +------------------ array_api_tests/test_elementwise_functions.py | 18 +++++++++--------- 3 files changed, 11 insertions(+), 42 deletions(-) diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index 1c333a92..b5f5a06f 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -9,7 +9,6 @@ _numeric_dtypes, _boolean_dtypes, _dtypes, asarray) from . import _array_module -from .dtype_helpers import promotion_table # These are exported here so that they can be included in the special cases # tests from this file. @@ -27,7 +26,7 @@ 'assert_isinf', 'positive_mathematical_sign', 'assert_positive_mathematical_sign', 'negative_mathematical_sign', 'assert_negative_mathematical_sign', 'same_sign', - 'assert_same_sign', 'ndindex', 'promote_dtypes', 'float64', + 'assert_same_sign', 'ndindex', 'float64', 'asarray', 'is_integer_dtype', 'is_float_dtype', 'dtype_ranges', 'full', 'true', 'false', 'isnan'] @@ -365,16 +364,3 @@ def ndindex(shape): """ return itertools.product(*[range(i) for i in shape]) - -def promote_dtypes(dtype1, dtype2): - """ - Special case of result_type() which uses the exact type promotion table - from the spec. - """ - try: - return promotion_table[(dtype1, dtype2)] - except KeyError as e: - raise ValueError( - f"{dtype1} and {dtype2} are not type promotable according to the spec" - f"(this may indicate a bug in the test suite)." - ) from e diff --git a/array_api_tests/meta_tests/test_array_helpers.py b/array_api_tests/meta_tests/test_array_helpers.py index 4f5b016a..1a2ae832 100644 --- a/array_api_tests/meta_tests/test_array_helpers.py +++ b/array_api_tests/meta_tests/test_array_helpers.py @@ -2,7 +2,7 @@ from hypothesis import given, assume from hypothesis.strategies import integers -from ..array_helpers import exactly_equal, notequal, int_to_dtype, promote_dtypes +from ..array_helpers import exactly_equal, notequal, int_to_dtype from ..hypothesis_helpers import integer_dtypes from ..dtype_helpers import dtype_nbits, dtype_signed from .. import _array_module as xp @@ -32,20 +32,3 @@ def test_int_to_dtype(x, dtype): except OverflowError: assume(False) assert int_to_dtype(x, n, signed) == d - -@pytest.mark.parametrize( - "dtype1, dtype2, result", - [ - (xp.uint8, xp.uint8, xp.uint8), - (xp.uint8, xp.int8, xp.int16), - (xp.int8, xp.int8, xp.int8), - ] -) -def test_promote_dtypes(dtype1, dtype2, result): - assert promote_dtypes(dtype1, dtype2) == result - - -@pytest.mark.parametrize("dtype1, dtype2", [(xp.uint8, xp.float32)]) -def test_promote_dtypes_incompatible_dtypes_fail(dtype1, dtype2): - with pytest.raises(ValueError): - promote_dtypes(dtype1, dtype2) diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index a0e59191..f950add2 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -34,11 +34,11 @@ infinity, isnegative, all as array_all, any as array_any, int_to_dtype, bool as bool_dtype, assert_integral, less_equal, isintegral, isfinite, - ndindex, promote_dtypes, is_integer_dtype, + ndindex, is_integer_dtype, is_float_dtype, not_equal, float64, asarray, dtype_ranges, full, true, false, assert_same_sign, isnan, less) -from .dtype_helpers import dtype_nbits, dtype_signed +from .dtype_helpers import dtype_nbits, dtype_signed, promotion_table # We might as well use this implementation rather than requiring # mod.broadcast_shapes(). See test_equal() and others. from .test_broadcasting import broadcast_shapes @@ -66,7 +66,7 @@ def two_array_scalars(draw, dtype1, dtype2): def sanity_check(x1, x2): try: - promote_dtypes(x1.dtype, x2.dtype) + promotion_table[x1.dtype, x2.dtype] except ValueError: raise RuntimeError("Error in test generation (probably a bug in the test suite") @@ -400,7 +400,7 @@ def test_equal(x1, x2): # test_type_promotion.py. The type promotion for equal() is not *really* # tested in that file, because doing so requires doing the consistency # check we do here rather than just checking the result dtype. - promoted_dtype = promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = promotion_table[x1.dtype, x2.dtype] _x1 = _array_module.asarray(_x1, dtype=promoted_dtype) _x2 = _array_module.asarray(_x2, dtype=promoted_dtype) @@ -487,7 +487,7 @@ def test_greater(args): _x1 = _array_module.broadcast_to(x1, shape) _x2 = _array_module.broadcast_to(x2, shape) - promoted_dtype = promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = promotion_table[x1.dtype, x2.dtype] _x1 = _array_module.asarray(_x1, dtype=promoted_dtype) _x2 = _array_module.asarray(_x2, dtype=promoted_dtype) @@ -517,7 +517,7 @@ def test_greater_equal(args): _x1 = _array_module.broadcast_to(x1, shape) _x2 = _array_module.broadcast_to(x2, shape) - promoted_dtype = promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = promotion_table[x1.dtype, x2.dtype] _x1 = _array_module.asarray(_x1, dtype=promoted_dtype) _x2 = _array_module.asarray(_x2, dtype=promoted_dtype) @@ -593,7 +593,7 @@ def test_less(args): _x1 = _array_module.broadcast_to(x1, shape) _x2 = _array_module.broadcast_to(x2, shape) - promoted_dtype = promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = promotion_table[x1.dtype, x2.dtype] _x1 = _array_module.asarray(_x1, dtype=promoted_dtype) _x2 = _array_module.asarray(_x2, dtype=promoted_dtype) @@ -623,7 +623,7 @@ def test_less_equal(args): _x1 = _array_module.broadcast_to(x1, shape) _x2 = _array_module.broadcast_to(x2, shape) - promoted_dtype = promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = promotion_table[x1.dtype, x2.dtype] _x1 = _array_module.asarray(_x1, dtype=promoted_dtype) _x2 = _array_module.asarray(_x2, dtype=promoted_dtype) @@ -786,7 +786,7 @@ def test_not_equal(args): _x1 = _array_module.broadcast_to(x1, shape) _x2 = _array_module.broadcast_to(x2, shape) - promoted_dtype = promote_dtypes(x1.dtype, x2.dtype) + promoted_dtype = promotion_table[x1.dtype, x2.dtype] _x1 = _array_module.asarray(_x1, dtype=promoted_dtype) _x2 = _array_module.asarray(_x2, dtype=promoted_dtype) From c9d958e2dd57806919a1007b9273bd8f2b6f7aeb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Fri, 8 Oct 2021 10:49:43 +0100 Subject: [PATCH 08/41] Refactor `dtype_helpers` --- array_api_tests/dtype_helpers.py | 130 +++++++++++-------------------- 1 file changed, 45 insertions(+), 85 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 43091a88..6b6b383d 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -2,74 +2,68 @@ __all__ = [ - "promotion_table", - "dtype_nbits", - "dtype_signed", - "input_types", - "dtypes_to_scalars", - "elementwise_function_input_types", - "elementwise_function_output_types", - "binary_operators", - "unary_operators", - "operators_to_functions", + 'dtypes_to_scalars', + 'input_types', + 'promotion_table', + 'dtype_nbits', + 'dtype_signed', + 'binary_operators', + 'unary_operators', + 'operators_to_functions', + 'elementwise_function_input_types', + 'elementwise_function_output_types', ] -dtype_nbits = { - **{d: 8 for d in [xp.int8, xp.uint8]}, - **{d: 16 for d in [xp.int16, xp.uint16]}, - **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, - **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, +int_dtypes = (xp.int8, xp.int16, xp.int32, xp.int64) +uint_dtypes = (xp.uint8, xp.uint16, xp.uint32, xp.uint64) +all_int_dtypes = int_dtypes + uint_dtypes +float_dtypes = (xp.float32, xp.float64) +numeric_dtypes = all_int_dtypes + float_dtypes +all_dtypes = (xp.bool,) + numeric_dtypes + + +dtypes_to_scalars = { + xp.bool: [bool], + **{d: [int] for d in all_int_dtypes}, + **{d: [int, float] for d in float_dtypes}, } -dtype_signed = { - **{d: True for d in [xp.int8, xp.int16, xp.int32, xp.int64]}, - **{d: False for d in [xp.uint8, xp.uint16, xp.uint32, xp.uint64]}, +input_types = { + 'any': all_dtypes, + 'boolean': (xp.bool,), + 'floating': float_dtypes, + 'integer': all_int_dtypes, + 'integer_or_boolean': (xp.bool,) + uint_dtypes + int_dtypes, + 'numeric': numeric_dtypes, } -signed_integer_promotion_table = { +_numeric_promotions = { + # ints (xp.int8, xp.int8): xp.int8, (xp.int8, xp.int16): xp.int16, (xp.int8, xp.int32): xp.int32, (xp.int8, xp.int64): xp.int64, - (xp.int16, xp.int8): xp.int16, (xp.int16, xp.int16): xp.int16, (xp.int16, xp.int32): xp.int32, (xp.int16, xp.int64): xp.int64, - (xp.int32, xp.int8): xp.int32, - (xp.int32, xp.int16): xp.int32, (xp.int32, xp.int32): xp.int32, (xp.int32, xp.int64): xp.int64, - (xp.int64, xp.int8): xp.int64, - (xp.int64, xp.int16): xp.int64, - (xp.int64, xp.int32): xp.int64, (xp.int64, xp.int64): xp.int64, -} - - -unsigned_integer_promotion_table = { + # uints (xp.uint8, xp.uint8): xp.uint8, (xp.uint8, xp.uint16): xp.uint16, (xp.uint8, xp.uint32): xp.uint32, (xp.uint8, xp.uint64): xp.uint64, - (xp.uint16, xp.uint8): xp.uint16, (xp.uint16, xp.uint16): xp.uint16, (xp.uint16, xp.uint32): xp.uint32, (xp.uint16, xp.uint64): xp.uint64, - (xp.uint32, xp.uint8): xp.uint32, - (xp.uint32, xp.uint16): xp.uint32, (xp.uint32, xp.uint32): xp.uint32, (xp.uint32, xp.uint64): xp.uint64, - (xp.uint64, xp.uint8): xp.uint64, - (xp.uint64, xp.uint16): xp.uint64, - (xp.uint64, xp.uint32): xp.uint64, (xp.uint64, xp.uint64): xp.uint64, -} - - -mixed_signed_unsigned_promotion_table = { + # ints and uints (mixed sign) (xp.int8, xp.uint8): xp.int16, (xp.int8, xp.uint16): xp.int32, (xp.int8, xp.uint32): xp.int64, @@ -82,63 +76,29 @@ (xp.int64, xp.uint8): xp.int64, (xp.int64, xp.uint16): xp.int64, (xp.int64, xp.uint32): xp.int64, -} - - -flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()} - - -float_promotion_table = { + # floats (xp.float32, xp.float32): xp.float32, (xp.float32, xp.float64): xp.float64, - (xp.float64, xp.float32): xp.float64, (xp.float64, xp.float64): xp.float64, } - - -boolean_promotion_table = { - (xp.bool, xp.bool): xp.bool, -} - - promotion_table = { - **signed_integer_promotion_table, - **unsigned_integer_promotion_table, - **mixed_signed_unsigned_promotion_table, - **flipped_mixed_signed_unsigned_promotion_table, - **float_promotion_table, - **boolean_promotion_table, + (xp.bool, xp.bool): xp.bool, + **_numeric_promotions, + **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, } -input_types = { - 'any': sorted(set(promotion_table.values())), - 'boolean': sorted(set(boolean_promotion_table.values())), - 'floating': sorted(set(float_promotion_table.values())), - 'integer': sorted(set({**signed_integer_promotion_table, - **unsigned_integer_promotion_table}.values())), - 'integer_or_boolean': sorted(set({**signed_integer_promotion_table, - **unsigned_integer_promotion_table, - **boolean_promotion_table}.values())), - 'numeric': sorted(set({**float_promotion_table, - **signed_integer_promotion_table, - **unsigned_integer_promotion_table}.values())), +dtype_nbits = { + **{d: 8 for d in [xp.int8, xp.uint8]}, + **{d: 16 for d in [xp.int16, xp.uint16]}, + **{d: 32 for d in [xp.int32, xp.uint32, xp.float32]}, + **{d: 64 for d in [xp.int64, xp.uint64, xp.float64]}, } -dtypes_to_scalars = { - xp.bool: [bool], - xp.int8: [int], - xp.int16: [int], - xp.int32: [int], - xp.int64: [int], - # Note: unsigned int dtypes only correspond to positive integers - xp.uint8: [int], - xp.uint16: [int], - xp.uint32: [int], - xp.uint64: [int], - xp.float32: [int, float], - xp.float64: [int, float], +dtype_signed = { + **{d: True for d in int_dtypes}, + **{d: False for d in uint_dtypes}, } From c07ba0e80e493c898d7a7aa7694efe84adbc9288 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 11 Oct 2021 11:44:28 +0100 Subject: [PATCH 09/41] Rudimentary elementwise parameters refactor --- array_api_tests/dtype_helpers.py | 13 +- array_api_tests/test_broadcasting.py | 4 +- array_api_tests/test_signatures.py | 6 +- array_api_tests/test_type_promotion.py | 266 ++++++++++--------------- 4 files changed, 120 insertions(+), 169 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 6b6b383d..d27862cd 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -3,15 +3,15 @@ __all__ = [ 'dtypes_to_scalars', - 'input_types', + 'category_to_dtypes', 'promotion_table', 'dtype_nbits', 'dtype_signed', + 'func_in_categories', + 'func_out_categories', 'binary_operators', 'unary_operators', 'operators_to_functions', - 'elementwise_function_input_types', - 'elementwise_function_output_types', ] @@ -30,7 +30,7 @@ } -input_types = { +category_to_dtypes = { 'any': all_dtypes, 'boolean': (xp.bool,), 'floating': float_dtypes, @@ -84,6 +84,7 @@ promotion_table = { (xp.bool, xp.bool): xp.bool, **_numeric_promotions, + # TODO: dont unpack pairs of the same dtype **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, } @@ -102,7 +103,7 @@ } -elementwise_function_input_types = { +func_in_categories = { 'abs': 'numeric', 'acos': 'floating', 'acosh': 'floating', @@ -162,7 +163,7 @@ } -elementwise_function_output_types = { +func_out_categories = { 'abs': 'promoted', 'acos': 'promoted', 'acosh': 'promoted', diff --git a/array_api_tests/test_broadcasting.py b/array_api_tests/test_broadcasting.py index f310e209..c8753082 100644 --- a/array_api_tests/test_broadcasting.py +++ b/array_api_tests/test_broadcasting.py @@ -10,7 +10,7 @@ from .hypothesis_helpers import shapes, FILTER_UNDEFINED_DTYPES from .pytest_helpers import raises, doesnt_raise, nargs -from .dtype_helpers import elementwise_function_input_types, input_types +from .dtype_helpers import func_in_categories, category_to_dtypes from .function_stubs import elementwise_functions from . import _array_module from ._array_module import ones, _UndefinedStub @@ -115,7 +115,7 @@ def test_broadcasting_hypothesis(func_name, shape1, shape2, data): # Internal consistency checks assert nargs(func_name) == 2 - dtype = data.draw(sampled_from(input_types[elementwise_function_input_types[func_name]])) + dtype = data.draw(sampled_from(category_to_dtypes[func_in_categories[func_name]])) if FILTER_UNDEFINED_DTYPES: assume(not isinstance(dtype, _UndefinedStub)) func = getattr(_array_module, func_name) diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index 237a7446..964b1c4e 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -4,7 +4,7 @@ from ._array_module import mod, mod_name, ones, eye, float64, bool, int64 from .pytest_helpers import raises, doesnt_raise -from .dtype_helpers import elementwise_function_input_types, operators_to_functions +from .dtype_helpers import func_in_categories, operators_to_functions from . import function_stubs @@ -163,9 +163,9 @@ def test_function_positional_args(name): n = operators_to_functions[name[:2] + name[3:]] else: n = operators_to_functions.get(name, name) - if 'boolean' in elementwise_function_input_types.get(n, 'floating'): + if 'boolean' in func_in_categories.get(n, 'floating'): dtype = bool - elif 'integer' in elementwise_function_input_types.get(n, 'floating'): + elif 'integer' in func_in_categories.get(n, 'floating'): dtype = int64 if array_method(name): diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 37e9d9ba..7eb60faf 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -2,6 +2,9 @@ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ +from itertools import product +from typing import Literal, Iterator + import pytest from hypothesis import given @@ -17,35 +20,37 @@ from . import _array_module from .dtype_helpers import ( promotion_table, - input_types, + category_to_dtypes, dtypes_to_scalars, - elementwise_function_input_types, - elementwise_function_output_types, + func_in_categories, + func_out_categories, binary_operators, unary_operators, operators_to_functions, ) -elementwise_function_two_arg_func_names = [func_name for func_name in - elementwise_functions.__all__ if - nargs(func_name) > 1] - -elementwise_function_two_arg_func_names_bool = [func_name for func_name in - elementwise_function_two_arg_func_names - if - elementwise_function_output_types[func_name] - == 'bool'] - -elementwise_function_two_arg_bool_parametrize_inputs = [(func_name, dtypes) - for func_name in elementwise_function_two_arg_func_names_bool - for dtypes in promotion_table.keys() if all(d in - input_types[elementwise_function_input_types[func_name]] - for d in dtypes) - ] +def generate_params( + in_nargs: int, + out_category: Literal['bool', 'promoted'], +) -> Iterator: + funcs = [ + f for f in elementwise_functions.__all__ + if nargs(f) == in_nargs and func_out_categories[f] == out_category + ] + if in_nargs == 1: + for func in funcs: + in_category = func_in_categories[func] + for in_dtype in category_to_dtypes[in_category]: + yield pytest.param(func, in_dtype, id=f"{func}({in_dtype})") + else: + for func, ((d1, d2), d3) in product(funcs, promotion_table.items()): + if all(d in category_to_dtypes[func_in_categories[func]] for d in (d1, d2)): + if out_category == 'bool': + yield pytest.param(func, (d1, d2), id=f"{func}({d1}, {d2})") + else: + yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") -elementwise_function_two_arg_bool_parametrize_ids = [f"{n}-{d1}-{d2}" for n, (d1, d2) - in elementwise_function_two_arg_bool_parametrize_inputs] # TODO: These functions should still do type promotion internally, but we do # not test this here (it is tested in the corresponding tests in @@ -54,27 +59,23 @@ # array(1.00000001, dtype=float64)) will be wrong if the float64 array is # downcast to float32. See for instance # https://github.com/numpy/numpy/issues/10322. -@pytest.mark.parametrize('func_name,dtypes', - elementwise_function_two_arg_bool_parametrize_inputs, - ids=elementwise_function_two_arg_bool_parametrize_ids) +@pytest.mark.parametrize('func, dtypes', generate_params(in_nargs=2, out_category='bool')) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(two_shapes=two_mutually_broadcastable_shapes, fillvalues=data()) -def test_elementwise_function_two_arg_bool_type_promotion(func_name, - two_shapes, dtypes, - fillvalues): - assert nargs(func_name) == 2 - func = getattr(_array_module, func_name) +@given(two_shapes=two_mutually_broadcastable_shapes, data=data()) +def test_elementwise_two_args_bool_type_promotion(func, two_shapes, dtypes, data): + assert nargs(func) == 2 + func = getattr(_array_module, func) dtype1, dtype2 = dtypes - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) - if func_name in ['bitwise_left_shift', 'bitwise_right_shift']: - fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + fillvalue1 = data.draw(scalars(just(dtype1))) + if func in ['bitwise_left_shift', 'bitwise_right_shift']: + fillvalue2 = data.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) else: - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) + fillvalue2 = data.draw(scalars(just(dtype2))) for i in [func, dtype1, dtype2]: @@ -86,46 +87,28 @@ def test_elementwise_function_two_arg_bool_type_promotion(func_name, a2 = full(shape2, fillvalue2, dtype=dtype2) res = func(a1, a2) - assert res.dtype == bool_dtype, f"{func_name}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to bool (shapes={shape1, shape2})" - -elementwise_function_two_arg_func_names_promoted = [func_name for func_name in - elementwise_function_two_arg_func_names - if - elementwise_function_output_types[func_name] - == 'promoted'] - -elementwise_function_two_arg_promoted_parametrize_inputs = [(func_name, dtypes) - for func_name in elementwise_function_two_arg_func_names_promoted - for dtypes in promotion_table.items() if all(d in - input_types[elementwise_function_input_types[func_name]] - for d in dtypes[0]) - ] - -elementwise_function_two_arg_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _) - in elementwise_function_two_arg_promoted_parametrize_inputs] + assert res.dtype == bool_dtype, f"{func}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to bool (shapes={shape1, shape2})" # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('func_name,dtypes', - elementwise_function_two_arg_promoted_parametrize_inputs, - ids=elementwise_function_two_arg_promoted_parametrize_ids) +@pytest.mark.parametrize('func, dtypes', generate_params(in_nargs=2, out_category='promoted')) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(two_shapes=two_mutually_broadcastable_shapes, fillvalues=data()) -def test_elementwise_function_two_arg_promoted_type_promotion(func_name, +@given(two_shapes=two_mutually_broadcastable_shapes, data=data()) +def test_elementwise_two_args_promoted_type_promotion(func, two_shapes, dtypes, - fillvalues): - assert nargs(func_name) == 2 - func = getattr(_array_module, func_name) + data): + assert nargs(func) == 2 + func = getattr(_array_module, func) (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) - if func_name in ['bitwise_left_shift', 'bitwise_right_shift']: - fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + fillvalue1 = data.draw(scalars(just(dtype1))) + if func in ['bitwise_left_shift', 'bitwise_right_shift']: + fillvalue2 = data.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) else: - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) + fillvalue2 = data.draw(scalars(just(dtype2))) for i in [func, dtype1, dtype2, res_dtype]: @@ -137,40 +120,21 @@ def test_elementwise_function_two_arg_promoted_type_promotion(func_name, a2 = full(shape2, fillvalue2, dtype=dtype2) res = func(a1, a2) - assert res.dtype == res_dtype, f"{func_name}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shapes={shape1, shape2})" - -elementwise_function_one_arg_func_names = [func_name for func_name in - elementwise_functions.__all__ if - nargs(func_name) == 1] - -elementwise_function_one_arg_func_names_bool = [func_name for func_name in - elementwise_function_one_arg_func_names - if - elementwise_function_output_types[func_name] - == 'bool'] - -elementwise_function_one_arg_bool_parametrize_inputs = [(func_name, dtypes) - for func_name in elementwise_function_one_arg_func_names_bool - for dtypes in input_types[elementwise_function_input_types[func_name]]] -elementwise_function_one_arg_bool_parametrize_ids = [f"{n}-{d}" for n, d - in elementwise_function_one_arg_bool_parametrize_inputs] + assert res.dtype == res_dtype, f"{func}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shapes={shape1, shape2})" # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('func_name,dtype', - elementwise_function_one_arg_bool_parametrize_inputs, - ids=elementwise_function_one_arg_bool_parametrize_ids) +@pytest.mark.parametrize('func, dtype', generate_params(in_nargs=1, out_category='bool')) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(shape=shapes, fillvalues=data()) -def test_elementwise_function_one_arg_bool(func_name, shape, - dtype, fillvalues): - assert nargs(func_name) == 1 - func = getattr(_array_module, func_name) +@given(shape=shapes, data=data()) +def test_elementwise_one_arg_bool(func, shape, dtype, data): + assert nargs(func) == 1 + func = getattr(_array_module, func) - fillvalue = fillvalues.draw(scalars(just(dtype))) + fillvalue = data.draw(scalars(just(dtype))) for i in [func, dtype]: if isinstance(i, _array_module._UndefinedStub): @@ -179,36 +143,22 @@ def test_elementwise_function_one_arg_bool(func_name, shape, x = full(shape, fillvalue, dtype=dtype) res = func(x) - assert res.dtype == bool_dtype, f"{func_name}({dtype}) returned to {res.dtype}, should have promoted to bool (shape={shape})" - -elementwise_function_one_arg_func_names_promoted = [func_name for func_name in - elementwise_function_one_arg_func_names - if - elementwise_function_output_types[func_name] - == 'promoted'] - -elementwise_function_one_arg_promoted_parametrize_inputs = [(func_name, dtypes) - for func_name in elementwise_function_one_arg_func_names_promoted - for dtypes in input_types[elementwise_function_input_types[func_name]]] -elementwise_function_one_arg_promoted_parametrize_ids = [f"{n}-{d}" for n, d - in elementwise_function_one_arg_promoted_parametrize_inputs] + assert res.dtype == bool_dtype, f"{func}({dtype}) returned to {res.dtype}, should have promoted to bool (shape={shape})" # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('func_name,dtype', - elementwise_function_one_arg_promoted_parametrize_inputs, - ids=elementwise_function_one_arg_promoted_parametrize_ids) +@pytest.mark.parametrize('func,dtype', generate_params(in_nargs=1, out_category='promoted')) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(shape=shapes, fillvalues=data()) -def test_elementwise_function_one_arg_type_promotion(func_name, shape, - dtype, fillvalues): - assert nargs(func_name) == 1 - func = getattr(_array_module, func_name) +@given(shape=shapes, data=data()) +def test_elementwise_one_arg_type_promotion(func, shape, + dtype, data): + assert nargs(func) == 1 + func = getattr(_array_module, func) - fillvalue = fillvalues.draw(scalars(just(dtype))) + fillvalue = data.draw(scalars(just(dtype))) for i in [func, dtype]: if isinstance(i, _array_module._UndefinedStub): @@ -217,13 +167,13 @@ def test_elementwise_function_one_arg_type_promotion(func_name, shape, x = full(shape, fillvalue, dtype=dtype) res = func(x) - assert res.dtype == dtype, f"{func_name}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" + assert res.dtype == dtype, f"{func}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" unary_operators_promoted = [unary_op_name for unary_op_name in sorted(unary_operators) - if elementwise_function_output_types[operators_to_functions[unary_op_name]] == 'promoted'] + if func_out_categories[operators_to_functions[unary_op_name]] == 'promoted'] operator_one_arg_promoted_parametrize_inputs = [(unary_op_name, dtypes) for unary_op_name in unary_operators_promoted - for dtypes in input_types[elementwise_function_input_types[operators_to_functions[unary_op_name]]] + for dtypes in category_to_dtypes[func_in_categories[operators_to_functions[unary_op_name]]] ] operator_one_arg_promoted_parametrize_ids = [f"{n}-{d}" for n, d in operator_one_arg_promoted_parametrize_inputs] @@ -238,11 +188,11 @@ def test_elementwise_function_one_arg_type_promotion(func_name, shape, # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(shape=shapes, fillvalues=data()) -def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, fillvalues): +@given(shape=shapes, data=data()) +def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, data): unary_op = unary_operators[unary_op_name] - fillvalue = fillvalues.draw(scalars(just(dtype))) + fillvalue = data.draw(scalars(just(dtype))) if isinstance(dtype, _array_module._UndefinedStub): dtype._raise() @@ -262,26 +212,26 @@ def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, fillvalues # Note: the boolean binary operators do not have reversed or in-place variants binary_operators_bool = [binary_op_name for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - if elementwise_function_output_types[operators_to_functions[binary_op_name]] == 'bool'] -operator_two_arg_bool_parametrize_inputs = [(binary_op_name, dtypes) + if func_out_categories[operators_to_functions[binary_op_name]] == 'bool'] +operator_two_args_bool_parametrize_inputs = [(binary_op_name, dtypes) for binary_op_name in binary_operators_bool for dtypes in promotion_table.keys() - if all(d in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] for d in dtypes) + if all(d in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]] for d in dtypes) ] -operator_two_arg_bool_parametrize_ids = [f"{n}-{d1}-{d2}" for n, (d1, d2) - in operator_two_arg_bool_parametrize_inputs] +operator_two_args_bool_parametrize_ids = [f"{n}-{d1}-{d2}" for n, (d1, d2) + in operator_two_args_bool_parametrize_inputs] @pytest.mark.parametrize('binary_op_name,dtypes', - operator_two_arg_bool_parametrize_inputs, - ids=operator_two_arg_bool_parametrize_ids) -@given(two_shapes=two_mutually_broadcastable_shapes, fillvalues=data()) -def test_operator_two_arg_bool_promotion(binary_op_name, dtypes, two_shapes, - fillvalues): + operator_two_args_bool_parametrize_inputs, + ids=operator_two_args_bool_parametrize_ids) +@given(two_shapes=two_mutually_broadcastable_shapes, data=data()) +def test_operator_two_args_bool_promotion(binary_op_name, dtypes, two_shapes, + data): binary_op = binary_operators[binary_op_name] dtype1, dtype2 = dtypes - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) + fillvalue1 = data.draw(scalars(just(dtype1))) + fillvalue2 = data.draw(scalars(just(dtype2))) for i in [dtype1, dtype2]: if isinstance(i, _array_module._UndefinedStub): @@ -298,29 +248,29 @@ def test_operator_two_arg_bool_promotion(binary_op_name, dtypes, two_shapes, assert res.dtype == bool_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})" binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - if elementwise_function_output_types[operators_to_functions[binary_op_name]] == 'promoted'] -operator_two_arg_promoted_parametrize_inputs = [(binary_op_name, dtypes) + if func_out_categories[operators_to_functions[binary_op_name]] == 'promoted'] +operator_two_args_promoted_parametrize_inputs = [(binary_op_name, dtypes) for binary_op_name in binary_operators_promoted for dtypes in promotion_table.items() - if all(d in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] for d in dtypes[0]) + if all(d in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]] for d in dtypes[0]) ] -operator_two_arg_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _) - in operator_two_arg_promoted_parametrize_inputs] +operator_two_args_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _) + in operator_two_args_promoted_parametrize_inputs] @pytest.mark.parametrize('binary_op_name,dtypes', - operator_two_arg_promoted_parametrize_inputs, - ids=operator_two_arg_promoted_parametrize_ids) -@given(two_shapes=two_mutually_broadcastable_shapes, fillvalues=data()) -def test_operator_two_arg_promoted_promotion(binary_op_name, dtypes, two_shapes, - fillvalues): + operator_two_args_promoted_parametrize_inputs, + ids=operator_two_args_promoted_parametrize_ids) +@given(two_shapes=two_mutually_broadcastable_shapes, data=data()) +def test_operator_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes, + data): binary_op = binary_operators[binary_op_name] (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) + fillvalue1 = data.draw(scalars(just(dtype1))) if binary_op_name in ['>>', '<<']: - fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + fillvalue2 = data.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) else: - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) + fillvalue2 = data.draw(scalars(just(dtype2))) for i in [dtype1, dtype2, res_dtype]: @@ -337,25 +287,25 @@ def test_operator_two_arg_promoted_promotion(binary_op_name, dtypes, two_shapes, assert res.dtype == res_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" -operator_inplace_two_arg_promoted_parametrize_inputs = [(binary_op, dtypes) for binary_op, dtypes in operator_two_arg_promoted_parametrize_inputs +operator_inplace_two_args_promoted_parametrize_inputs = [(binary_op, dtypes) for binary_op, dtypes in operator_two_args_promoted_parametrize_inputs if dtypes[0][0] == dtypes[1]] -operator_inplace_two_arg_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], str(d1), str(d2))) for n, ((d1, d2), _) - in operator_inplace_two_arg_promoted_parametrize_inputs] +operator_inplace_two_args_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], str(d1), str(d2))) for n, ((d1, d2), _) + in operator_inplace_two_args_promoted_parametrize_inputs] @pytest.mark.parametrize('binary_op_name,dtypes', - operator_inplace_two_arg_promoted_parametrize_inputs, - ids=operator_inplace_two_arg_promoted_parametrize_ids) -@given(two_shapes=two_broadcastable_shapes(), fillvalues=data()) -def test_operator_inplace_two_arg_promoted_promotion(binary_op_name, dtypes, two_shapes, - fillvalues): + operator_inplace_two_args_promoted_parametrize_inputs, + ids=operator_inplace_two_args_promoted_parametrize_ids) +@given(two_shapes=two_broadcastable_shapes(), data=data()) +def test_operator_inplace_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes, + data): binary_op = binary_operators[binary_op_name] (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = fillvalues.draw(scalars(just(dtype1))) + fillvalue1 = data.draw(scalars(just(dtype1))) if binary_op_name in ['>>', '<<']: - fillvalue2 = fillvalues.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + fillvalue2 = data.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) else: - fillvalue2 = fillvalues.draw(scalars(just(dtype2))) + fillvalue2 = data.draw(scalars(just(dtype2))) for i in [dtype1, dtype2, res_dtype]: if isinstance(i, _array_module._UndefinedStub): @@ -377,15 +327,15 @@ def test_operator_inplace_two_arg_promoted_promotion(binary_op_name, dtypes, two scalar_promotion_parametrize_inputs = [ pytest.param(binary_op_name, dtype, scalar_type, id=f"{binary_op_name}-{dtype}-{scalar_type.__name__}") for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - for dtype in input_types[elementwise_function_input_types[operators_to_functions[binary_op_name]]] + for dtype in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]] for scalar_type in dtypes_to_scalars[dtype] ] @pytest.mark.parametrize('binary_op_name,dtype,scalar_type', scalar_promotion_parametrize_inputs) -@given(shape=shapes, python_scalars=data(), fillvalues=data()) +@given(shape=shapes, python_scalars=data(), data=data()) def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type, - shape, python_scalars, fillvalues): + shape, python_scalars, data): """ See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars """ @@ -393,14 +343,14 @@ def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type, if binary_op == '@': pytest.skip("matmul (@) is not supported for scalars") - if dtype in input_types['integer']: + if dtype in category_to_dtypes['integer']: s = python_scalars.draw(integers(*dtype_ranges[dtype])) else: s = python_scalars.draw(from_type(scalar_type)) scalar_as_array = _array_module.asarray(s, dtype=dtype) get_locals = lambda: dict(a=a, s=s, scalar_as_array=scalar_as_array) - fillvalue = fillvalues.draw(scalars(just(dtype))) + fillvalue = data.draw(scalars(just(dtype))) a = full(shape, fillvalue, dtype=dtype) # As per the spec: From 19cfff7984e2fc79bcb9d00033590e1fb64e2326 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 11 Oct 2021 14:01:24 +0100 Subject: [PATCH 10/41] Rudimentary operator parameters refactor --- array_api_tests/dtype_helpers.py | 12 +-- array_api_tests/test_signatures.py | 6 +- array_api_tests/test_type_promotion.py | 133 ++++++++++++++----------- 3 files changed, 82 insertions(+), 69 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index d27862cd..384195c3 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -9,9 +9,9 @@ 'dtype_signed', 'func_in_categories', 'func_out_categories', - 'binary_operators', - 'unary_operators', - 'operators_to_functions', + 'binary_op_to_symbol', + 'unary_op_to_symbol', + 'op_to_func', ] @@ -223,7 +223,7 @@ } -binary_operators = { +binary_op_to_symbol = { '__add__': '+', '__and__': '&', '__eq__': '==', @@ -246,7 +246,7 @@ } -unary_operators = { +unary_op_to_symbol = { '__abs__': 'abs()', '__invert__': '~', '__neg__': '-', @@ -254,7 +254,7 @@ } -operators_to_functions = { +op_to_func = { '__abs__': 'abs', '__add__': 'add', '__and__': 'bitwise_and', diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index 964b1c4e..2f11729b 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -4,7 +4,7 @@ from ._array_module import mod, mod_name, ones, eye, float64, bool, int64 from .pytest_helpers import raises, doesnt_raise -from .dtype_helpers import func_in_categories, operators_to_functions +from .dtype_helpers import func_in_categories, op_to_func from . import function_stubs @@ -160,9 +160,9 @@ def test_function_positional_args(name): dtype = None if (name.startswith('__i') and name not in ['__int__', '__invert__', '__index__'] or name.startswith('__r') and name != '__rshift__'): - n = operators_to_functions[name[:2] + name[3:]] + n = op_to_func[name[:2] + name[3:]] else: - n = operators_to_functions.get(name, name) + n = op_to_func.get(name, name) if 'boolean' in func_in_categories.get(n, 'floating'): dtype = bool elif 'integer' in func_in_categories.get(n, 'floating'): diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 7eb60faf..12b0402b 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -24,32 +24,57 @@ dtypes_to_scalars, func_in_categories, func_out_categories, - binary_operators, - unary_operators, - operators_to_functions, + binary_op_to_symbol, + unary_op_to_symbol, + op_to_func, ) def generate_params( + func_family: Literal['elementwise', 'operator'], in_nargs: int, out_category: Literal['bool', 'promoted'], ) -> Iterator: - funcs = [ - f for f in elementwise_functions.__all__ - if nargs(f) == in_nargs and func_out_categories[f] == out_category - ] - if in_nargs == 1: - for func in funcs: - in_category = func_in_categories[func] - for in_dtype in category_to_dtypes[in_category]: - yield pytest.param(func, in_dtype, id=f"{func}({in_dtype})") + if func_family == 'elementwise': + funcs = [ + f for f in elementwise_functions.__all__ + if nargs(f) == in_nargs and func_out_categories[f] == out_category + ] + if in_nargs == 1: + for func in funcs: + in_category = func_in_categories[func] + for in_dtype in category_to_dtypes[in_category]: + yield pytest.param(func, in_dtype, id=f"{func}({in_dtype})") + else: + for func, ((d1, d2), d3) in product(funcs, promotion_table.items()): + if all(d in category_to_dtypes[func_in_categories[func]] for d in (d1, d2)): + if out_category == 'bool': + yield pytest.param(func, (d1, d2), id=f"{func}({d1}, {d2})") + else: + yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") else: - for func, ((d1, d2), d3) in product(funcs, promotion_table.items()): - if all(d in category_to_dtypes[func_in_categories[func]] for d in (d1, d2)): - if out_category == 'bool': - yield pytest.param(func, (d1, d2), id=f"{func}({d1}, {d2})") - else: - yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") + if in_nargs == 1: + for op, symbol in unary_op_to_symbol.items(): + func = op_to_func[op] + if func_out_categories[func] == out_category: + in_category = func_in_categories[func] + for in_dtype in category_to_dtypes[in_category]: + yield pytest.param(op, symbol, in_dtype, id=f"{op}({in_dtype})") + else: + for op, symbol in binary_op_to_symbol.items(): + if op == "__matmul__": + continue + func = op_to_func[op] + if func_out_categories[func] == out_category: + in_category = func_in_categories[func] + for ((d1, d2), d3) in promotion_table.items(): + if all(d in category_to_dtypes[in_category] for d in (d1, d2)): + if out_category == 'bool': + yield pytest.param(op, symbol, (d1, d2), id=f"{op}({d1}, {d2})") + else: + if d1 == d3: + yield pytest.param(op, symbol, ((d1, d2), d3), id=f"{op}({d1}, {d2}) -> {d3}") + # TODO: These functions should still do type promotion internally, but we do @@ -59,7 +84,7 @@ def generate_params( # array(1.00000001, dtype=float64)) will be wrong if the float64 array is # downcast to float32. See for instance # https://github.com/numpy/numpy/issues/10322. -@pytest.mark.parametrize('func, dtypes', generate_params(in_nargs=2, out_category='bool')) +@pytest.mark.parametrize('func, dtypes', generate_params('elementwise', in_nargs=2, out_category='bool')) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. @@ -91,7 +116,7 @@ def test_elementwise_two_args_bool_type_promotion(func, two_shapes, dtypes, data # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('func, dtypes', generate_params(in_nargs=2, out_category='promoted')) +@pytest.mark.parametrize('func, dtypes', generate_params('elementwise', in_nargs=2, out_category='promoted')) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. @@ -124,7 +149,7 @@ def test_elementwise_two_args_promoted_type_promotion(func, # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('func, dtype', generate_params(in_nargs=1, out_category='bool')) +@pytest.mark.parametrize('func, dtype', generate_params('elementwise', in_nargs=1, out_category='bool')) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. @@ -147,7 +172,7 @@ def test_elementwise_one_arg_bool(func, shape, dtype, data): # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('func,dtype', generate_params(in_nargs=1, out_category='promoted')) +@pytest.mark.parametrize('func,dtype', generate_params('elementwise', in_nargs=1, out_category='promoted')) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. @@ -169,11 +194,11 @@ def test_elementwise_one_arg_type_promotion(func, shape, assert res.dtype == dtype, f"{func}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" -unary_operators_promoted = [unary_op_name for unary_op_name in sorted(unary_operators) - if func_out_categories[operators_to_functions[unary_op_name]] == 'promoted'] +unary_operators_promoted = [unary_op_name for unary_op_name in sorted(unary_op_to_symbol) + if func_out_categories[op_to_func[unary_op_name]] == 'promoted'] operator_one_arg_promoted_parametrize_inputs = [(unary_op_name, dtypes) for unary_op_name in unary_operators_promoted - for dtypes in category_to_dtypes[func_in_categories[operators_to_functions[unary_op_name]]] + for dtypes in category_to_dtypes[func_in_categories[op_to_func[unary_op_name]]] ] operator_one_arg_promoted_parametrize_ids = [f"{n}-{d}" for n, d in operator_one_arg_promoted_parametrize_inputs] @@ -181,17 +206,16 @@ def test_elementwise_one_arg_type_promotion(func, shape, # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args -@pytest.mark.parametrize('unary_op_name,dtype', - operator_one_arg_promoted_parametrize_inputs, - ids=operator_one_arg_promoted_parametrize_ids) +@pytest.mark.parametrize( + 'unary_op_name, unary_op, dtype', + generate_params('operator', in_nargs=1, out_category='promoted'), +) # The spec explicitly requires type promotion to work for shape 0 # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) @given(shape=shapes, data=data()) -def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, data): - unary_op = unary_operators[unary_op_name] - +def test_operator_one_arg_type_promotion(unary_op_name, unary_op, shape, dtype, data): fillvalue = data.draw(scalars(just(dtype))) if isinstance(dtype, _array_module._UndefinedStub): @@ -211,24 +235,22 @@ def test_operator_one_arg_type_promotion(unary_op_name, shape, dtype, data): assert res.dtype == dtype, f"{unary_op}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" # Note: the boolean binary operators do not have reversed or in-place variants -binary_operators_bool = [binary_op_name for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - if func_out_categories[operators_to_functions[binary_op_name]] == 'bool'] +binary_operators_bool = [binary_op_name for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'}) + if func_out_categories[op_to_func[binary_op_name]] == 'bool'] operator_two_args_bool_parametrize_inputs = [(binary_op_name, dtypes) for binary_op_name in binary_operators_bool for dtypes in promotion_table.keys() - if all(d in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]] for d in dtypes) + if all(d in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]] for d in dtypes) ] operator_two_args_bool_parametrize_ids = [f"{n}-{d1}-{d2}" for n, (d1, d2) in operator_two_args_bool_parametrize_inputs] -@pytest.mark.parametrize('binary_op_name,dtypes', - operator_two_args_bool_parametrize_inputs, - ids=operator_two_args_bool_parametrize_ids) +@pytest.mark.parametrize( + 'binary_op_name, binary_op, dtypes', + generate_params('operator', in_nargs=2, out_category='bool') +) @given(two_shapes=two_mutually_broadcastable_shapes, data=data()) -def test_operator_two_args_bool_promotion(binary_op_name, dtypes, two_shapes, - data): - binary_op = binary_operators[binary_op_name] - +def test_operator_two_args_bool_promotion(binary_op_name, binary_op, dtypes, two_shapes, data): dtype1, dtype2 = dtypes fillvalue1 = data.draw(scalars(just(dtype1))) fillvalue2 = data.draw(scalars(just(dtype2))) @@ -247,24 +269,19 @@ def test_operator_two_args_bool_promotion(binary_op_name, dtypes, two_shapes, assert res.dtype == bool_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})" -binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - if func_out_categories[operators_to_functions[binary_op_name]] == 'promoted'] +binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'}) + if func_out_categories[op_to_func[binary_op_name]] == 'promoted'] operator_two_args_promoted_parametrize_inputs = [(binary_op_name, dtypes) for binary_op_name in binary_operators_promoted for dtypes in promotion_table.items() - if all(d in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]] for d in dtypes[0]) + if all(d in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]] for d in dtypes[0]) ] operator_two_args_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _) in operator_two_args_promoted_parametrize_inputs] -@pytest.mark.parametrize('binary_op_name,dtypes', - operator_two_args_promoted_parametrize_inputs, - ids=operator_two_args_promoted_parametrize_ids) +@pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) @given(two_shapes=two_mutually_broadcastable_shapes, data=data()) -def test_operator_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes, - data): - binary_op = binary_operators[binary_op_name] - +def test_operator_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, two_shapes, data): (dtype1, dtype2), res_dtype = dtypes fillvalue1 = data.draw(scalars(just(dtype1))) if binary_op_name in ['>>', '<<']: @@ -292,14 +309,10 @@ def test_operator_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes operator_inplace_two_args_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], str(d1), str(d2))) for n, ((d1, d2), _) in operator_inplace_two_args_promoted_parametrize_inputs] -@pytest.mark.parametrize('binary_op_name,dtypes', - operator_inplace_two_args_promoted_parametrize_inputs, - ids=operator_inplace_two_args_promoted_parametrize_ids) +@pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) @given(two_shapes=two_broadcastable_shapes(), data=data()) -def test_operator_inplace_two_args_promoted_promotion(binary_op_name, dtypes, two_shapes, +def test_operator_inplace_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, two_shapes, data): - binary_op = binary_operators[binary_op_name] - (dtype1, dtype2), res_dtype = dtypes fillvalue1 = data.draw(scalars(just(dtype1))) if binary_op_name in ['>>', '<<']: @@ -326,8 +339,8 @@ def test_operator_inplace_two_args_promoted_promotion(binary_op_name, dtypes, tw scalar_promotion_parametrize_inputs = [ pytest.param(binary_op_name, dtype, scalar_type, id=f"{binary_op_name}-{dtype}-{scalar_type.__name__}") - for binary_op_name in sorted(set(binary_operators) - {'__matmul__'}) - for dtype in category_to_dtypes[func_in_categories[operators_to_functions[binary_op_name]]] + for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'}) + for dtype in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]] for scalar_type in dtypes_to_scalars[dtype] ] @@ -339,7 +352,7 @@ def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type, """ See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars """ - binary_op = binary_operators[binary_op_name] + binary_op = binary_op_to_symbol[binary_op_name] if binary_op == '@': pytest.skip("matmul (@) is not supported for scalars") From 1c06936c0c9c2b24f9e5e0aea7bfe5063cc0f95e Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 11 Oct 2021 16:46:58 +0100 Subject: [PATCH 11/41] Namespace some imports --- array_api_tests/test_type_promotion.py | 232 +++++++++++-------------- 1 file changed, 100 insertions(+), 132 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 12b0402b..30eff2fb 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -3,33 +3,21 @@ """ from itertools import product -from typing import Literal, Iterator +from typing import Iterator, Literal import pytest - from hypothesis import given -from hypothesis.strategies import from_type, data, integers, just - -from .hypothesis_helpers import (shapes, two_mutually_broadcastable_shapes, - two_broadcastable_shapes, scalars) -from .pytest_helpers import nargs -from .array_helpers import assert_exactly_equal, dtype_ranges +from hypothesis import strategies as st +from . import _array_module as xp +from . import array_helpers as ah +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh from .function_stubs import elementwise_functions -from ._array_module import (full, bool as bool_dtype) -from . import _array_module -from .dtype_helpers import ( - promotion_table, - category_to_dtypes, - dtypes_to_scalars, - func_in_categories, - func_out_categories, - binary_op_to_symbol, - unary_op_to_symbol, - op_to_func, -) +from .pytest_helpers import nargs +# Note: the boolean binary operators do not have reversed or in-place variants def generate_params( func_family: Literal['elementwise', 'operator'], in_nargs: int, @@ -38,37 +26,37 @@ def generate_params( if func_family == 'elementwise': funcs = [ f for f in elementwise_functions.__all__ - if nargs(f) == in_nargs and func_out_categories[f] == out_category + if nargs(f) == in_nargs and dh.func_out_categories[f] == out_category ] if in_nargs == 1: for func in funcs: - in_category = func_in_categories[func] - for in_dtype in category_to_dtypes[in_category]: + in_category = dh.func_in_categories[func] + for in_dtype in dh.category_to_dtypes[in_category]: yield pytest.param(func, in_dtype, id=f"{func}({in_dtype})") else: - for func, ((d1, d2), d3) in product(funcs, promotion_table.items()): - if all(d in category_to_dtypes[func_in_categories[func]] for d in (d1, d2)): + for func, ((d1, d2), d3) in product(funcs, dh.promotion_table.items()): + if all(d in dh.category_to_dtypes[dh.func_in_categories[func]] for d in (d1, d2)): if out_category == 'bool': yield pytest.param(func, (d1, d2), id=f"{func}({d1}, {d2})") else: yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") else: if in_nargs == 1: - for op, symbol in unary_op_to_symbol.items(): - func = op_to_func[op] - if func_out_categories[func] == out_category: - in_category = func_in_categories[func] - for in_dtype in category_to_dtypes[in_category]: + for op, symbol in dh.unary_op_to_symbol.items(): + func = dh.op_to_func[op] + if dh.func_out_categories[func] == out_category: + in_category = dh.func_in_categories[func] + for in_dtype in dh.category_to_dtypes[in_category]: yield pytest.param(op, symbol, in_dtype, id=f"{op}({in_dtype})") else: - for op, symbol in binary_op_to_symbol.items(): + for op, symbol in dh.binary_op_to_symbol.items(): if op == "__matmul__": continue - func = op_to_func[op] - if func_out_categories[func] == out_category: - in_category = func_in_categories[func] - for ((d1, d2), d3) in promotion_table.items(): - if all(d in category_to_dtypes[in_category] for d in (d1, d2)): + func = dh.op_to_func[op] + if dh.func_out_categories[func] == out_category: + in_category = dh.func_in_categories[func] + for ((d1, d2), d3) in dh.promotion_table.items(): + if all(d in dh.category_to_dtypes[in_category] for d in (d1, d2)): if out_category == 'bool': yield pytest.param(op, symbol, (d1, d2), id=f"{op}({d1}, {d2})") else: @@ -89,30 +77,30 @@ def generate_params( # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(two_shapes=two_mutually_broadcastable_shapes, data=data()) +@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) def test_elementwise_two_args_bool_type_promotion(func, two_shapes, dtypes, data): assert nargs(func) == 2 - func = getattr(_array_module, func) + func = getattr(xp, func) dtype1, dtype2 = dtypes - fillvalue1 = data.draw(scalars(just(dtype1))) + fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) if func in ['bitwise_left_shift', 'bitwise_right_shift']: - fillvalue2 = data.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) else: - fillvalue2 = data.draw(scalars(just(dtype2))) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) for i in [func, dtype1, dtype2]: - if isinstance(i, _array_module._UndefinedStub): + if isinstance(i, xp._UndefinedStub): i._raise() shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) + a1 = ah.full(shape1, fillvalue1, dtype=dtype1) + a2 = ah.full(shape2, fillvalue2, dtype=dtype2) res = func(a1, a2) - assert res.dtype == bool_dtype, f"{func}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to bool (shapes={shape1, shape2})" + assert res.dtype == xp.bool, f"{func}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to bool (shapes={shape1, shape2})" # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args @@ -121,28 +109,28 @@ def test_elementwise_two_args_bool_type_promotion(func, two_shapes, dtypes, data # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(two_shapes=two_mutually_broadcastable_shapes, data=data()) +@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) def test_elementwise_two_args_promoted_type_promotion(func, two_shapes, dtypes, data): assert nargs(func) == 2 - func = getattr(_array_module, func) + func = getattr(xp, func) (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = data.draw(scalars(just(dtype1))) + fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) if func in ['bitwise_left_shift', 'bitwise_right_shift']: - fillvalue2 = data.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) else: - fillvalue2 = data.draw(scalars(just(dtype2))) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) for i in [func, dtype1, dtype2, res_dtype]: - if isinstance(i, _array_module._UndefinedStub): + if isinstance(i, xp._UndefinedStub): i._raise() shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) + a1 = ah.full(shape1, fillvalue1, dtype=dtype1) + a2 = ah.full(shape2, fillvalue2, dtype=dtype2) res = func(a1, a2) assert res.dtype == res_dtype, f"{func}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shapes={shape1, shape2})" @@ -154,21 +142,21 @@ def test_elementwise_two_args_promoted_type_promotion(func, # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(shape=shapes, data=data()) +@given(shape=hh.shapes, data=st.data()) def test_elementwise_one_arg_bool(func, shape, dtype, data): assert nargs(func) == 1 - func = getattr(_array_module, func) + func = getattr(xp, func) - fillvalue = data.draw(scalars(just(dtype))) + fillvalue = data.draw(hh.scalars(st.just(dtype))) for i in [func, dtype]: - if isinstance(i, _array_module._UndefinedStub): + if isinstance(i, xp._UndefinedStub): i._raise() - x = full(shape, fillvalue, dtype=dtype) + x = ah.full(shape, fillvalue, dtype=dtype) res = func(x) - assert res.dtype == bool_dtype, f"{func}({dtype}) returned to {res.dtype}, should have promoted to bool (shape={shape})" + assert res.dtype == xp.bool, f"{func}({dtype}) returned to {res.dtype}, should have promoted to bool (shape={shape})" # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args @@ -177,32 +165,23 @@ def test_elementwise_one_arg_bool(func, shape, dtype, data): # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(shape=shapes, data=data()) +@given(shape=hh.shapes, data=st.data()) def test_elementwise_one_arg_type_promotion(func, shape, dtype, data): assert nargs(func) == 1 - func = getattr(_array_module, func) + func = getattr(xp, func) - fillvalue = data.draw(scalars(just(dtype))) + fillvalue = data.draw(hh.scalars(st.just(dtype))) for i in [func, dtype]: - if isinstance(i, _array_module._UndefinedStub): + if isinstance(i, xp._UndefinedStub): i._raise() - x = full(shape, fillvalue, dtype=dtype) + x = ah.full(shape, fillvalue, dtype=dtype) res = func(x) assert res.dtype == dtype, f"{func}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" -unary_operators_promoted = [unary_op_name for unary_op_name in sorted(unary_op_to_symbol) - if func_out_categories[op_to_func[unary_op_name]] == 'promoted'] -operator_one_arg_promoted_parametrize_inputs = [(unary_op_name, dtypes) - for unary_op_name in unary_operators_promoted - for dtypes in category_to_dtypes[func_in_categories[op_to_func[unary_op_name]]] - ] -operator_one_arg_promoted_parametrize_ids = [f"{n}-{d}" for n, d - in operator_one_arg_promoted_parametrize_inputs] - # TODO: Extend this to all functions (not just elementwise), and handle # functions that take more than 2 args @@ -214,14 +193,14 @@ def test_elementwise_one_arg_type_promotion(func, shape, # Unfortunately, data(), isn't compatible with @example, so this is commented # out for now. # @example(shape=(0,)) -@given(shape=shapes, data=data()) +@given(shape=hh.shapes, data=st.data()) def test_operator_one_arg_type_promotion(unary_op_name, unary_op, shape, dtype, data): - fillvalue = data.draw(scalars(just(dtype))) + fillvalue = data.draw(hh.scalars(st.just(dtype))) - if isinstance(dtype, _array_module._UndefinedStub): + if isinstance(dtype, xp._UndefinedStub): dtype._raise() - a = full(shape, fillvalue, dtype=dtype) + a = ah.full(shape, fillvalue, dtype=dtype) get_locals = lambda: dict(a=a) @@ -234,69 +213,58 @@ def test_operator_one_arg_type_promotion(unary_op_name, unary_op, shape, dtype, assert res.dtype == dtype, f"{unary_op}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" -# Note: the boolean binary operators do not have reversed or in-place variants -binary_operators_bool = [binary_op_name for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'}) - if func_out_categories[op_to_func[binary_op_name]] == 'bool'] -operator_two_args_bool_parametrize_inputs = [(binary_op_name, dtypes) - for binary_op_name in binary_operators_bool - for dtypes in promotion_table.keys() - if all(d in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]] for d in dtypes) - ] -operator_two_args_bool_parametrize_ids = [f"{n}-{d1}-{d2}" for n, (d1, d2) - in operator_two_args_bool_parametrize_inputs] - @pytest.mark.parametrize( 'binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='bool') ) -@given(two_shapes=two_mutually_broadcastable_shapes, data=data()) +@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) def test_operator_two_args_bool_promotion(binary_op_name, binary_op, dtypes, two_shapes, data): dtype1, dtype2 = dtypes - fillvalue1 = data.draw(scalars(just(dtype1))) - fillvalue2 = data.draw(scalars(just(dtype2))) + fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) for i in [dtype1, dtype2]: - if isinstance(i, _array_module._UndefinedStub): + if isinstance(i, xp._UndefinedStub): i._raise() shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) + a1 = ah.full(shape1, fillvalue1, dtype=dtype1) + a2 = ah.full(shape2, fillvalue2, dtype=dtype2) get_locals = lambda: dict(a1=a1, a2=a2) expression = f'a1 {binary_op} a2' res = eval(expression, get_locals()) - assert res.dtype == bool_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})" + assert res.dtype == xp.bool, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})" -binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'}) - if func_out_categories[op_to_func[binary_op_name]] == 'promoted'] +binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'}) + if dh.func_out_categories[dh.op_to_func[binary_op_name]] == 'promoted'] operator_two_args_promoted_parametrize_inputs = [(binary_op_name, dtypes) for binary_op_name in binary_operators_promoted - for dtypes in promotion_table.items() - if all(d in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]] for d in dtypes[0]) + for dtypes in dh.promotion_table.items() + if all(d in dh.category_to_dtypes[dh.func_in_categories[dh.op_to_func[binary_op_name]]] for d in dtypes[0]) ] operator_two_args_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _) in operator_two_args_promoted_parametrize_inputs] @pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) -@given(two_shapes=two_mutually_broadcastable_shapes, data=data()) +@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) def test_operator_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, two_shapes, data): (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = data.draw(scalars(just(dtype1))) + fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) if binary_op_name in ['>>', '<<']: - fillvalue2 = data.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) else: - fillvalue2 = data.draw(scalars(just(dtype2))) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) for i in [dtype1, dtype2, res_dtype]: - if isinstance(i, _array_module._UndefinedStub): + if isinstance(i, xp._UndefinedStub): i._raise() shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) + a1 = ah.full(shape1, fillvalue1, dtype=dtype1) + a2 = ah.full(shape2, fillvalue2, dtype=dtype2) get_locals = lambda: dict(a1=a1, a2=a2) expression = f'a1 {binary_op} a2' @@ -310,23 +278,23 @@ def test_operator_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, in operator_inplace_two_args_promoted_parametrize_inputs] @pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) -@given(two_shapes=two_broadcastable_shapes(), data=data()) +@given(two_shapes=hh.two_broadcastable_shapes(), data=st.data()) def test_operator_inplace_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, two_shapes, data): (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = data.draw(scalars(just(dtype1))) + fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) if binary_op_name in ['>>', '<<']: - fillvalue2 = data.draw(scalars(just(dtype2)).filter(lambda x: x > 0)) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) else: - fillvalue2 = data.draw(scalars(just(dtype2))) + fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) for i in [dtype1, dtype2, res_dtype]: - if isinstance(i, _array_module._UndefinedStub): + if isinstance(i, xp._UndefinedStub): i._raise() shape1, shape2 = two_shapes - a1 = full(shape1, fillvalue1, dtype=dtype1) - a2 = full(shape2, fillvalue2, dtype=dtype2) + a1 = ah.full(shape1, fillvalue1, dtype=dtype1) + a2 = ah.full(shape2, fillvalue2, dtype=dtype2) get_locals = lambda: dict(a1=a1, a2=a2) @@ -339,32 +307,32 @@ def test_operator_inplace_two_args_promoted_promotion(binary_op_name, binary_op, scalar_promotion_parametrize_inputs = [ pytest.param(binary_op_name, dtype, scalar_type, id=f"{binary_op_name}-{dtype}-{scalar_type.__name__}") - for binary_op_name in sorted(set(binary_op_to_symbol) - {'__matmul__'}) - for dtype in category_to_dtypes[func_in_categories[op_to_func[binary_op_name]]] - for scalar_type in dtypes_to_scalars[dtype] + for binary_op_name in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'}) + for dtype in dh.category_to_dtypes[dh.func_in_categories[dh.op_to_func[binary_op_name]]] + for scalar_type in dh.dtypes_to_scalars[dtype] ] @pytest.mark.parametrize('binary_op_name,dtype,scalar_type', scalar_promotion_parametrize_inputs) -@given(shape=shapes, python_scalars=data(), data=data()) +@given(shape=hh.shapes, python_scalars=st.data(), data=st.data()) def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type, shape, python_scalars, data): """ - See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars + See https://st.data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-hh.scalars """ - binary_op = binary_op_to_symbol[binary_op_name] + binary_op = dh.binary_op_to_symbol[binary_op_name] if binary_op == '@': - pytest.skip("matmul (@) is not supported for scalars") + pytest.skip("matmul (@) is not supported for hh.scalars") - if dtype in category_to_dtypes['integer']: - s = python_scalars.draw(integers(*dtype_ranges[dtype])) + if dtype in dh.category_to_dtypes['integer']: + s = python_scalars.draw(st.integers(*ah.dtype_ranges[dtype])) else: - s = python_scalars.draw(from_type(scalar_type)) - scalar_as_array = _array_module.asarray(s, dtype=dtype) + s = python_scalars.draw(st.from_type(scalar_type)) + scalar_as_array = ah.asarray(s, dtype=dtype) get_locals = lambda: dict(a=a, s=s, scalar_as_array=scalar_as_array) - fillvalue = data.draw(scalars(just(dtype))) - a = full(shape, fillvalue, dtype=dtype) + fillvalue = data.draw(hh.scalars(st.just(dtype))) + a = ah.full(shape, fillvalue, dtype=dtype) # As per the spec: @@ -380,30 +348,30 @@ def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type, array_scalar_expected = f'a {binary_op} scalar_as_array' res = eval(array_scalar, get_locals()) expected = eval(array_scalar_expected, get_locals()) - assert_exactly_equal(res, expected) + ah.assert_exactly_equal(res, expected) scalar_array = f's {binary_op} a' scalar_array_expected = f'scalar_as_array {binary_op} a' res = eval(scalar_array, get_locals()) expected = eval(scalar_array_expected, get_locals()) - assert_exactly_equal(res, expected) + ah.assert_exactly_equal(res, expected) # Test in-place operators if binary_op in ['==', '!=', '<', '>', '<=', '>=']: return array_scalar = f'a {binary_op}= s' array_scalar_expected = f'a {binary_op}= scalar_as_array' - a = full(shape, fillvalue, dtype=dtype) + a = ah.full(shape, fillvalue, dtype=dtype) res_locals = get_locals() exec(array_scalar, get_locals()) res = res_locals['a'] - a = full(shape, fillvalue, dtype=dtype) + a = ah.full(shape, fillvalue, dtype=dtype) expected_locals = get_locals() exec(array_scalar_expected, get_locals()) expected = expected_locals['a'] - assert_exactly_equal(res, expected) + ah.assert_exactly_equal(res, expected) if __name__ == '__main__': - for (i, j), p in promotion_table.items(): + for (i, j), p in dh.promotion_table.items(): print(f"({i}, {j}) -> {p}") From 85212daba0c849c8f67fc47af308e7623ed347d5 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 11 Oct 2021 17:07:31 +0100 Subject: [PATCH 12/41] Change promotion test names --- array_api_tests/test_type_promotion.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 30eff2fb..7f8cfd2d 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -78,7 +78,7 @@ def generate_params( # out for now. # @example(shape=(0,)) @given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_elementwise_two_args_bool_type_promotion(func, two_shapes, dtypes, data): +def test_elementwise_two_args_return_bool(func, two_shapes, dtypes, data): assert nargs(func) == 2 func = getattr(xp, func) @@ -110,7 +110,7 @@ def test_elementwise_two_args_bool_type_promotion(func, two_shapes, dtypes, data # out for now. # @example(shape=(0,)) @given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_elementwise_two_args_promoted_type_promotion(func, +def test_elementwise_two_args_return_promoted(func, two_shapes, dtypes, data): assert nargs(func) == 2 @@ -143,7 +143,7 @@ def test_elementwise_two_args_promoted_type_promotion(func, # out for now. # @example(shape=(0,)) @given(shape=hh.shapes, data=st.data()) -def test_elementwise_one_arg_bool(func, shape, dtype, data): +def test_elementwise_one_arg_return_bool(func, shape, dtype, data): assert nargs(func) == 1 func = getattr(xp, func) @@ -166,7 +166,7 @@ def test_elementwise_one_arg_bool(func, shape, dtype, data): # out for now. # @example(shape=(0,)) @given(shape=hh.shapes, data=st.data()) -def test_elementwise_one_arg_type_promotion(func, shape, +def test_elementwise_one_arg_return_promoted(func, shape, dtype, data): assert nargs(func) == 1 func = getattr(xp, func) @@ -194,7 +194,7 @@ def test_elementwise_one_arg_type_promotion(func, shape, # out for now. # @example(shape=(0,)) @given(shape=hh.shapes, data=st.data()) -def test_operator_one_arg_type_promotion(unary_op_name, unary_op, shape, dtype, data): +def test_operator_one_arg_return_promoted(unary_op_name, unary_op, shape, dtype, data): fillvalue = data.draw(hh.scalars(st.just(dtype))) if isinstance(dtype, xp._UndefinedStub): @@ -218,7 +218,7 @@ def test_operator_one_arg_type_promotion(unary_op_name, unary_op, shape, dtype, generate_params('operator', in_nargs=2, out_category='bool') ) @given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_operator_two_args_bool_promotion(binary_op_name, binary_op, dtypes, two_shapes, data): +def test_operator_two_args_return_bool(binary_op_name, binary_op, dtypes, two_shapes, data): dtype1, dtype2 = dtypes fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) @@ -249,7 +249,7 @@ def test_operator_two_args_bool_promotion(binary_op_name, binary_op, dtypes, two @pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) @given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_operator_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, two_shapes, data): +def test_operator_two_args_return_promoted(binary_op_name, binary_op, dtypes, two_shapes, data): (dtype1, dtype2), res_dtype = dtypes fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) if binary_op_name in ['>>', '<<']: @@ -279,7 +279,7 @@ def test_operator_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, @pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) @given(two_shapes=hh.two_broadcastable_shapes(), data=st.data()) -def test_operator_inplace_two_args_promoted_promotion(binary_op_name, binary_op, dtypes, two_shapes, +def test_operator_inplace_two_args_return_promoted(binary_op_name, binary_op, dtypes, two_shapes, data): (dtype1, dtype2), res_dtype = dtypes fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) @@ -315,7 +315,7 @@ def test_operator_inplace_two_args_promoted_promotion(binary_op_name, binary_op, @pytest.mark.parametrize('binary_op_name,dtype,scalar_type', scalar_promotion_parametrize_inputs) @given(shape=hh.shapes, python_scalars=st.data(), data=st.data()) -def test_operator_scalar_promotion(binary_op_name, dtype, scalar_type, +def test_operator_scalar_arg_return_promoted(binary_op_name, dtype, scalar_type, shape, python_scalars, data): """ See https://st.data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-hh.scalars From ce6a7e8c4c4cd205bc80b6c3b956a6b5e9bee115 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 11 Oct 2021 17:37:41 +0100 Subject: [PATCH 13/41] Clarify operator-related variable names --- array_api_tests/dtype_helpers.py | 18 +++--- array_api_tests/test_type_promotion.py | 87 +++++++++++--------------- 2 files changed, 47 insertions(+), 58 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 384195c3..f35b435c 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -9,9 +9,8 @@ 'dtype_signed', 'func_in_categories', 'func_out_categories', - 'binary_op_to_symbol', - 'unary_op_to_symbol', - 'op_to_func', + 'binary_func_to_op', + 'unary_func_to_op', ] @@ -223,7 +222,7 @@ } -binary_op_to_symbol = { +binary_func_to_op = { '__add__': '+', '__and__': '&', '__eq__': '==', @@ -246,7 +245,7 @@ } -unary_op_to_symbol = { +unary_func_to_op = { '__abs__': 'abs()', '__invert__': '~', '__neg__': '-', @@ -254,7 +253,7 @@ } -op_to_func = { +_operator_to_elementwise = { '__abs__': 'abs', '__add__': 'add', '__and__': 'bitwise_and', @@ -265,7 +264,7 @@ '__le__': 'less_equal', '__lshift__': 'bitwise_left_shift', '__lt__': 'less', - '__matmul__': 'matmul', + # '__matmul__': 'matmul', # TODO: support matmul '__mod__': 'remainder', '__mul__': 'multiply', '__ne__': 'not_equal', @@ -279,3 +278,8 @@ '__neg__': 'negative', '__pos__': 'positive', } + + +for op_func, elwise_func in _operator_to_elementwise.items(): + func_in_categories[op_func] = func_in_categories[elwise_func] + func_out_categories[op_func] = func_out_categories[elwise_func] diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 7f8cfd2d..7fd37ab6 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -42,26 +42,26 @@ def generate_params( yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") else: if in_nargs == 1: - for op, symbol in dh.unary_op_to_symbol.items(): - func = dh.op_to_func[op] + for func, op in dh.unary_func_to_op.items(): + if func == "__matmul__": + continue if dh.func_out_categories[func] == out_category: in_category = dh.func_in_categories[func] for in_dtype in dh.category_to_dtypes[in_category]: - yield pytest.param(op, symbol, in_dtype, id=f"{op}({in_dtype})") + yield pytest.param(func, op, in_dtype, id=f"{func}({in_dtype})") else: - for op, symbol in dh.binary_op_to_symbol.items(): - if op == "__matmul__": + for func, op in dh.binary_func_to_op.items(): + if func == "__matmul__": continue - func = dh.op_to_func[op] if dh.func_out_categories[func] == out_category: in_category = dh.func_in_categories[func] for ((d1, d2), d3) in dh.promotion_table.items(): if all(d in dh.category_to_dtypes[in_category] for d in (d1, d2)): if out_category == 'bool': - yield pytest.param(op, symbol, (d1, d2), id=f"{op}({d1}, {d2})") + yield pytest.param(func, op, (d1, d2), id=f"{func}({d1}, {d2})") else: if d1 == d3: - yield pytest.param(op, symbol, ((d1, d2), d3), id=f"{op}({d1}, {d2}) -> {d3}") + yield pytest.param(func, op, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") @@ -214,11 +214,11 @@ def test_operator_one_arg_return_promoted(unary_op_name, unary_op, shape, dtype, assert res.dtype == dtype, f"{unary_op}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" @pytest.mark.parametrize( - 'binary_op_name, binary_op, dtypes', + 'func, op, dtypes', generate_params('operator', in_nargs=2, out_category='bool') ) @given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_operator_two_args_return_bool(binary_op_name, binary_op, dtypes, two_shapes, data): +def test_operator_two_args_return_bool(func, op, dtypes, two_shapes, data): dtype1, dtype2 = dtypes fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) @@ -232,27 +232,17 @@ def test_operator_two_args_return_bool(binary_op_name, binary_op, dtypes, two_sh a2 = ah.full(shape2, fillvalue2, dtype=dtype2) get_locals = lambda: dict(a1=a1, a2=a2) - expression = f'a1 {binary_op} a2' + expression = f'a1 {op} a2' res = eval(expression, get_locals()) - assert res.dtype == xp.bool, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})" - -binary_operators_promoted = [binary_op_name for binary_op_name in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'}) - if dh.func_out_categories[dh.op_to_func[binary_op_name]] == 'promoted'] -operator_two_args_promoted_parametrize_inputs = [(binary_op_name, dtypes) - for binary_op_name in binary_operators_promoted - for dtypes in dh.promotion_table.items() - if all(d in dh.category_to_dtypes[dh.func_in_categories[dh.op_to_func[binary_op_name]]] for d in dtypes[0]) - ] -operator_two_args_promoted_parametrize_ids = [f"{n}-{d1}-{d2}" for n, ((d1, d2), _) - in operator_two_args_promoted_parametrize_inputs] + assert res.dtype == xp.bool, f"{dtype1} {op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})" -@pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) +@pytest.mark.parametrize('func, op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) @given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_operator_two_args_return_promoted(binary_op_name, binary_op, dtypes, two_shapes, data): +def test_operator_two_args_return_promoted(func, op, dtypes, two_shapes, data): (dtype1, dtype2), res_dtype = dtypes fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) - if binary_op_name in ['>>', '<<']: + if op in ['>>', '<<']: fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) else: fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) @@ -267,23 +257,18 @@ def test_operator_two_args_return_promoted(binary_op_name, binary_op, dtypes, tw a2 = ah.full(shape2, fillvalue2, dtype=dtype2) get_locals = lambda: dict(a1=a1, a2=a2) - expression = f'a1 {binary_op} a2' + expression = f'a1 {op} a2' res = eval(expression, get_locals()) - assert res.dtype == res_dtype, f"{dtype1} {binary_op} {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" - -operator_inplace_two_args_promoted_parametrize_inputs = [(binary_op, dtypes) for binary_op, dtypes in operator_two_args_promoted_parametrize_inputs - if dtypes[0][0] == dtypes[1]] -operator_inplace_two_args_promoted_parametrize_ids = ['-'.join((n[:2] + 'i' + n[2:], str(d1), str(d2))) for n, ((d1, d2), _) - in operator_inplace_two_args_promoted_parametrize_inputs] + assert res.dtype == res_dtype, f"{dtype1} {op} {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" -@pytest.mark.parametrize('binary_op_name, binary_op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) +@pytest.mark.parametrize('func, op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) @given(two_shapes=hh.two_broadcastable_shapes(), data=st.data()) -def test_operator_inplace_two_args_return_promoted(binary_op_name, binary_op, dtypes, two_shapes, +def test_operator_inplace_two_args_return_promoted(func, op, dtypes, two_shapes, data): (dtype1, dtype2), res_dtype = dtypes fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) - if binary_op_name in ['>>', '<<']: + if func in ['>>', '<<']: fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) else: fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) @@ -299,29 +284,29 @@ def test_operator_inplace_two_args_return_promoted(binary_op_name, binary_op, dt get_locals = lambda: dict(a1=a1, a2=a2) res_locals = get_locals() - expression = f'a1 {binary_op}= a2' + expression = f'a1 {op}= a2' exec(expression, res_locals) res = res_locals['a1'] - assert res.dtype == res_dtype, f"{dtype1} {binary_op}= {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" + assert res.dtype == res_dtype, f"{dtype1} {op}= {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" scalar_promotion_parametrize_inputs = [ - pytest.param(binary_op_name, dtype, scalar_type, id=f"{binary_op_name}-{dtype}-{scalar_type.__name__}") - for binary_op_name in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'}) - for dtype in dh.category_to_dtypes[dh.func_in_categories[dh.op_to_func[binary_op_name]]] + pytest.param(func, dtype, scalar_type, id=f"{func}-{dtype}-{scalar_type.__name__}") + for func in sorted(set(dh.binary_func_to_op) - {'__matmul__'}) + for dtype in dh.category_to_dtypes[dh.func_in_categories[func]] for scalar_type in dh.dtypes_to_scalars[dtype] ] -@pytest.mark.parametrize('binary_op_name,dtype,scalar_type', +@pytest.mark.parametrize('func,dtype,scalar_type', scalar_promotion_parametrize_inputs) @given(shape=hh.shapes, python_scalars=st.data(), data=st.data()) -def test_operator_scalar_arg_return_promoted(binary_op_name, dtype, scalar_type, +def test_operator_scalar_arg_return_promoted(func, dtype, scalar_type, shape, python_scalars, data): """ See https://st.data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-hh.scalars """ - binary_op = dh.binary_op_to_symbol[binary_op_name] - if binary_op == '@': + op = dh.binary_func_to_op[func] + if op == '@': pytest.skip("matmul (@) is not supported for hh.scalars") if dtype in dh.category_to_dtypes['integer']: @@ -344,23 +329,23 @@ def test_operator_scalar_arg_return_promoted(binary_op_name, dtype, scalar_type, # 2. Execute the operation for `array 0-D array` (or `0-D array # array` if `scalar` was the left-hand argument). - array_scalar = f'a {binary_op} s' - array_scalar_expected = f'a {binary_op} scalar_as_array' + array_scalar = f'a {op} s' + array_scalar_expected = f'a {op} scalar_as_array' res = eval(array_scalar, get_locals()) expected = eval(array_scalar_expected, get_locals()) ah.assert_exactly_equal(res, expected) - scalar_array = f's {binary_op} a' - scalar_array_expected = f'scalar_as_array {binary_op} a' + scalar_array = f's {op} a' + scalar_array_expected = f'scalar_as_array {op} a' res = eval(scalar_array, get_locals()) expected = eval(scalar_array_expected, get_locals()) ah.assert_exactly_equal(res, expected) # Test in-place operators - if binary_op in ['==', '!=', '<', '>', '<=', '>=']: + if op in ['==', '!=', '<', '>', '<=', '>=']: return - array_scalar = f'a {binary_op}= s' - array_scalar_expected = f'a {binary_op}= scalar_as_array' + array_scalar = f'a {op}= s' + array_scalar_expected = f'a {op}= scalar_as_array' a = ah.full(shape, fillvalue, dtype=dtype) res_locals = get_locals() exec(array_scalar, get_locals()) From dd3fb86c392a1192ff9682722e7da25296d3404b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Oct 2021 14:58:23 +0100 Subject: [PATCH 14/41] Refactor type promotion tests --- array_api_tests/dtype_helpers.py | 29 +- array_api_tests/hypothesis_helpers.py | 15 +- array_api_tests/test_type_promotion.py | 375 +++++++++---------------- 3 files changed, 167 insertions(+), 252 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index f35b435c..287a99ba 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -9,8 +9,10 @@ 'dtype_signed', 'func_in_categories', 'func_out_categories', - 'binary_func_to_op', - 'unary_func_to_op', + 'op_in_categories', + 'op_out_categories', + 'binary_op_to_symbol', + 'unary_op_to_symbol', ] @@ -222,7 +224,14 @@ } -binary_func_to_op = { +unary_op_to_symbol = { + '__invert__': '~', + '__neg__': '-', + '__pos__': '+', +} + + +binary_op_to_symbol = { '__add__': '+', '__and__': '&', '__eq__': '==', @@ -245,14 +254,6 @@ } -unary_func_to_op = { - '__abs__': 'abs()', - '__invert__': '~', - '__neg__': '-', - '__pos__': '+', -} - - _operator_to_elementwise = { '__abs__': 'abs', '__add__': 'add', @@ -280,6 +281,8 @@ } +op_in_categories = {} +op_out_categories = {} for op_func, elwise_func in _operator_to_elementwise.items(): - func_in_categories[op_func] = func_in_categories[elwise_func] - func_out_categories[op_func] = func_out_categories[elwise_func] + op_in_categories[op_func] = func_in_categories[elwise_func] + op_out_categories[op_func] = func_out_categories[elwise_func] diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 639421fb..c853a871 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -2,11 +2,13 @@ from operator import mul from math import sqrt import itertools +from typing import Tuple from hypothesis import assume from hypothesis.strategies import (lists, integers, sampled_from, shared, floats, just, composite, one_of, none, booleans) +from hypothesis.strategies._internal.strategies import SearchStrategy from .pytest_helpers import nargs from .array_helpers import (dtype_ranges, integer_dtype_objects, @@ -114,9 +116,16 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False) square_matrix_shapes = matrix_shapes.filter(lambda shape: shape[-1] == shape[-2]) -two_mutually_broadcastable_shapes = xps.mutually_broadcastable_shapes(num_shapes=2)\ - .map(lambda S: S.input_shapes)\ - .filter(lambda S: all(prod(i for i in shape if i) < MAX_ARRAY_SIZE for shape in S)) +def mutually_broadcastable_shapes(num_shapes: int) -> SearchStrategy[Tuple[Tuple]]: + return ( + xps.mutually_broadcastable_shapes(num_shapes) + .map(lambda BS: BS.input_shapes) + .filter(lambda shapes: all( + prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes + )) + ) + +two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2) # Note: This should become hermitian_matrices when complex dtypes are added @composite diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 7fd37ab6..8a899e44 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -6,13 +6,14 @@ from typing import Iterator, Literal import pytest -from hypothesis import given +from hypothesis import assume, given from hypothesis import strategies as st from . import _array_module as xp from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh +from . import xps from .function_stubs import elementwise_functions from .pytest_helpers import nargs @@ -42,19 +43,17 @@ def generate_params( yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") else: if in_nargs == 1: - for func, op in dh.unary_func_to_op.items(): - if func == "__matmul__": - continue - if dh.func_out_categories[func] == out_category: - in_category = dh.func_in_categories[func] + for func, op in dh.unary_op_to_symbol.items(): + if dh.op_out_categories[func] == out_category: + in_category = dh.op_in_categories[func] for in_dtype in dh.category_to_dtypes[in_category]: yield pytest.param(func, op, in_dtype, id=f"{func}({in_dtype})") else: - for func, op in dh.binary_func_to_op.items(): + for func, op in dh.binary_op_to_symbol.items(): if func == "__matmul__": continue - if dh.func_out_categories[func] == out_category: - in_category = dh.func_in_categories[func] + if dh.op_out_categories[func] == out_category: + in_category = dh.op_in_categories[func] for ((d1, d2), d3) in dh.promotion_table.items(): if all(d in dh.category_to_dtypes[in_category] for d in (d1, d2)): if out_category == 'bool': @@ -64,236 +63,140 @@ def generate_params( yield pytest.param(func, op, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") +def generate_func_params() -> Iterator: + for func_name in elementwise_functions.__all__: + func = getattr(xp, func_name) + in_category = dh.func_in_categories[func_name] + out_category = dh.func_out_categories[func_name] + valid_in_dtypes = dh.category_to_dtypes[in_category] + ndtypes = nargs(func_name) + if ndtypes == 1: + for in_dtype in valid_in_dtypes: + out_dtype = in_dtype if out_category == 'promoted' else xp.bool + yield pytest.param( + func, (in_dtype,), out_dtype, id=f"{func_name}({in_dtype}) -> {out_dtype}" + ) + elif ndtypes == 2: + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: + out_dtype = promoted_dtype if out_category == 'promoted' else xp.bool + yield pytest.param( + func, (in_dtype1, in_dtype2), out_dtype, id=f'{func_name}({in_dtype1}, {in_dtype2}) -> {out_dtype}' + ) + else: + raise NotImplementedError() + + +@pytest.mark.parametrize('func, in_dtypes, out_dtype', generate_func_params()) +@given(data=st.data()) +def test_func_returns_array_with_correct_dtype(func, in_dtypes, out_dtype, data): + arrays = [] + shapes = data.draw(hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes') + for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): + x = data.draw(xps.arrays(dtype=dtype, shape=shape), label='x{i}') + arrays.append(x) + out = func(*arrays) + assert out.dtype == out_dtype, f"{out.dtype=!s}, but should be {out_dtype}" + + +def generate_unary_op_params() -> Iterator: + for op, symbol in dh.unary_op_to_symbol.items(): + if op == '__abs__': + continue + in_category = dh.op_in_categories[op] + out_category = dh.op_out_categories[op] + valid_in_dtypes = dh.category_to_dtypes[in_category] + for in_dtype in valid_in_dtypes: + out_dtype = in_dtype if out_category == 'promoted' else xp.bool + yield pytest.param(symbol, in_dtype, out_dtype, id=f'{op}({in_dtype}) -> {out_dtype}') + + +@pytest.mark.parametrize('op_symbol, in_dtype, out_dtype', generate_unary_op_params()) +@given(data=st.data()) +def test_unary_operator_returns_array_with_correct_dtype(op_symbol, in_dtype, out_dtype, data): + x = data.draw(xps.arrays(dtype=in_dtype, shape=hh.shapes), label='x') + out = eval(f'{op_symbol}x', {"x": x}) + assert out.dtype == out_dtype, f"{out.dtype=!s}, but should be {out_dtype}" + + +def generate_abs_op_params() -> Iterator: + in_category = dh.op_in_categories['__abs__'] + out_category = dh.op_out_categories['__abs__'] + valid_in_dtypes = dh.category_to_dtypes[in_category] + for in_dtype in valid_in_dtypes: + out_dtype = in_dtype if out_category == 'promoted' else xp.bool + yield pytest.param(in_dtype, out_dtype, id=f'__abs__({in_dtype}) -> {out_dtype}') + + +@pytest.mark.parametrize('in_dtype, out_dtype', generate_abs_op_params()) +@given(data=st.data()) +def test_abs_operator_returns_array_with_correct_dtype(in_dtype, out_dtype, data): + x = data.draw(xps.arrays(dtype=in_dtype, shape=hh.shapes), label='x') + out = eval('abs(x)', {"x": x}) + assert out.dtype == out_dtype, f"{out.dtype=!s}, but should be {out_dtype}" + + +def generate_binary_op_params() -> Iterator: + for op, symbol in dh.binary_op_to_symbol.items(): + if op == '__matmul__' or 'shift' in op: + continue + in_category = dh.op_in_categories[op] + out_category = dh.op_out_categories[op] + valid_in_dtypes = dh.category_to_dtypes[in_category] + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: + out_dtype = promoted_dtype if out_category == 'promoted' else xp.bool + yield pytest.param( + symbol, + (in_dtype1, in_dtype2), + out_dtype, + id=f'{op}({in_dtype1}, {in_dtype2}) -> {out_dtype}' + ) + + +@pytest.mark.parametrize('op_symbol, in_dtypes, out_dtype', generate_binary_op_params()) +@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) +def test_binary_operator_returns_array_with_correct_dtype(op_symbol, in_dtypes, out_dtype, shapes, data): + x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1') + x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2') + out = eval(f'x1 {op_symbol} x2', {"x1": x1, "x2": x2}) + assert out.dtype == out_dtype, f"{out.dtype=!s}, but should be {out_dtype}" + + +def generate_inplace_op_params() -> Iterator: + for op, symbol in dh.binary_op_to_symbol.items(): + if op == '__matmul__' or 'shift' in op or '=' in symbol or '<' in symbol or '>' in symbol: + continue + in_category = dh.op_in_categories[op] + out_category = dh.op_out_categories[op] + valid_in_dtypes = dh.category_to_dtypes[in_category] + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if in_dtype1 == promoted_dtype and in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: + out_dtype = promoted_dtype if out_category == 'promoted' else xp.bool + yield pytest.param( + f'{symbol}=', + (in_dtype1, in_dtype2), + out_dtype, + id=f'__i{op[2:]}({in_dtype1}, {in_dtype2}) -> {out_dtype}' + ) + + +@pytest.mark.parametrize('op_symbol, in_dtypes, out_dtype', generate_inplace_op_params()) +@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) +def test_inplace_operator_returns_array_with_correct_dtype(op_symbol, in_dtypes, out_dtype, shapes, data): + assume(len(shapes[0]) >= len(shapes[1])) + x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1') + x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2') + locals_ = {"x1": x1, "x2": x2} + exec(f'x1 {op_symbol} x2', locals_) + x1 = locals_["x1"] + assert x1.dtype == out_dtype, f"{x1.dtype=!s}, but should be {out_dtype}" -# TODO: These functions should still do type promotion internally, but we do -# not test this here (it is tested in the corresponding tests in -# test_elementwise_functions.py). This can affect the resulting values if not -# done correctly. For example, greater_equal(array(1.0, dtype=float32), -# array(1.00000001, dtype=float64)) will be wrong if the float64 array is -# downcast to float32. See for instance -# https://github.com/numpy/numpy/issues/10322. -@pytest.mark.parametrize('func, dtypes', generate_params('elementwise', in_nargs=2, out_category='bool')) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_elementwise_two_args_return_bool(func, two_shapes, dtypes, data): - assert nargs(func) == 2 - func = getattr(xp, func) - - dtype1, dtype2 = dtypes - - fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) - if func in ['bitwise_left_shift', 'bitwise_right_shift']: - fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) - else: - fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) - - - for i in [func, dtype1, dtype2]: - if isinstance(i, xp._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = ah.full(shape1, fillvalue1, dtype=dtype1) - a2 = ah.full(shape2, fillvalue2, dtype=dtype2) - res = func(a1, a2) - - assert res.dtype == xp.bool, f"{func}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to bool (shapes={shape1, shape2})" - -# TODO: Extend this to all functions (not just elementwise), and handle -# functions that take more than 2 args -@pytest.mark.parametrize('func, dtypes', generate_params('elementwise', in_nargs=2, out_category='promoted')) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_elementwise_two_args_return_promoted(func, - two_shapes, dtypes, - data): - assert nargs(func) == 2 - func = getattr(xp, func) - - (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) - if func in ['bitwise_left_shift', 'bitwise_right_shift']: - fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) - else: - fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) - - - for i in [func, dtype1, dtype2, res_dtype]: - if isinstance(i, xp._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = ah.full(shape1, fillvalue1, dtype=dtype1) - a2 = ah.full(shape2, fillvalue2, dtype=dtype2) - res = func(a1, a2) - - assert res.dtype == res_dtype, f"{func}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shapes={shape1, shape2})" - -# TODO: Extend this to all functions (not just elementwise), and handle -# functions that take more than 2 args -@pytest.mark.parametrize('func, dtype', generate_params('elementwise', in_nargs=1, out_category='bool')) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(shape=hh.shapes, data=st.data()) -def test_elementwise_one_arg_return_bool(func, shape, dtype, data): - assert nargs(func) == 1 - func = getattr(xp, func) - - fillvalue = data.draw(hh.scalars(st.just(dtype))) - - for i in [func, dtype]: - if isinstance(i, xp._UndefinedStub): - i._raise() - - x = ah.full(shape, fillvalue, dtype=dtype) - res = func(x) - - assert res.dtype == xp.bool, f"{func}({dtype}) returned to {res.dtype}, should have promoted to bool (shape={shape})" - -# TODO: Extend this to all functions (not just elementwise), and handle -# functions that take more than 2 args -@pytest.mark.parametrize('func,dtype', generate_params('elementwise', in_nargs=1, out_category='promoted')) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(shape=hh.shapes, data=st.data()) -def test_elementwise_one_arg_return_promoted(func, shape, - dtype, data): - assert nargs(func) == 1 - func = getattr(xp, func) - - fillvalue = data.draw(hh.scalars(st.just(dtype))) - - for i in [func, dtype]: - if isinstance(i, xp._UndefinedStub): - i._raise() - - x = ah.full(shape, fillvalue, dtype=dtype) - res = func(x) - - assert res.dtype == dtype, f"{func}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" - - -# TODO: Extend this to all functions (not just elementwise), and handle -# functions that take more than 2 args -@pytest.mark.parametrize( - 'unary_op_name, unary_op, dtype', - generate_params('operator', in_nargs=1, out_category='promoted'), -) -# The spec explicitly requires type promotion to work for shape 0 -# Unfortunately, data(), isn't compatible with @example, so this is commented -# out for now. -# @example(shape=(0,)) -@given(shape=hh.shapes, data=st.data()) -def test_operator_one_arg_return_promoted(unary_op_name, unary_op, shape, dtype, data): - fillvalue = data.draw(hh.scalars(st.just(dtype))) - - if isinstance(dtype, xp._UndefinedStub): - dtype._raise() - - a = ah.full(shape, fillvalue, dtype=dtype) - - get_locals = lambda: dict(a=a) - - if unary_op_name == '__abs__': - # This is the built-in abs(), not the array module's abs() - expression = 'abs(a)' - else: - expression = f'{unary_op} a' - res = eval(expression, get_locals()) - - assert res.dtype == dtype, f"{unary_op}({dtype}) returned to {res.dtype}, should have promoted to {dtype} (shape={shape})" - -@pytest.mark.parametrize( - 'func, op, dtypes', - generate_params('operator', in_nargs=2, out_category='bool') -) -@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_operator_two_args_return_bool(func, op, dtypes, two_shapes, data): - dtype1, dtype2 = dtypes - fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) - fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) - - for i in [dtype1, dtype2]: - if isinstance(i, xp._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = ah.full(shape1, fillvalue1, dtype=dtype1) - a2 = ah.full(shape2, fillvalue2, dtype=dtype2) - - get_locals = lambda: dict(a1=a1, a2=a2) - expression = f'a1 {op} a2' - res = eval(expression, get_locals()) - - assert res.dtype == xp.bool, f"{dtype1} {op} {dtype2} promoted to {res.dtype}, should have promoted to bool (shape={shape1, shape2})" - -@pytest.mark.parametrize('func, op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) -@given(two_shapes=hh.two_mutually_broadcastable_shapes, data=st.data()) -def test_operator_two_args_return_promoted(func, op, dtypes, two_shapes, data): - (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) - if op in ['>>', '<<']: - fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) - else: - fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) - - - for i in [dtype1, dtype2, res_dtype]: - if isinstance(i, xp._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = ah.full(shape1, fillvalue1, dtype=dtype1) - a2 = ah.full(shape2, fillvalue2, dtype=dtype2) - - get_locals = lambda: dict(a1=a1, a2=a2) - expression = f'a1 {op} a2' - res = eval(expression, get_locals()) - - assert res.dtype == res_dtype, f"{dtype1} {op} {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" - -@pytest.mark.parametrize('func, op, dtypes', generate_params('operator', in_nargs=2, out_category='promoted')) -@given(two_shapes=hh.two_broadcastable_shapes(), data=st.data()) -def test_operator_inplace_two_args_return_promoted(func, op, dtypes, two_shapes, - data): - (dtype1, dtype2), res_dtype = dtypes - fillvalue1 = data.draw(hh.scalars(st.just(dtype1))) - if func in ['>>', '<<']: - fillvalue2 = data.draw(hh.scalars(st.just(dtype2)).filter(lambda x: x > 0)) - else: - fillvalue2 = data.draw(hh.scalars(st.just(dtype2))) - - for i in [dtype1, dtype2, res_dtype]: - if isinstance(i, xp._UndefinedStub): - i._raise() - - shape1, shape2 = two_shapes - a1 = ah.full(shape1, fillvalue1, dtype=dtype1) - a2 = ah.full(shape2, fillvalue2, dtype=dtype2) - - get_locals = lambda: dict(a1=a1, a2=a2) - - res_locals = get_locals() - expression = f'a1 {op}= a2' - exec(expression, res_locals) - res = res_locals['a1'] - - assert res.dtype == res_dtype, f"{dtype1} {op}= {dtype2} promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape1, shape2})" scalar_promotion_parametrize_inputs = [ pytest.param(func, dtype, scalar_type, id=f"{func}-{dtype}-{scalar_type.__name__}") - for func in sorted(set(dh.binary_func_to_op) - {'__matmul__'}) - for dtype in dh.category_to_dtypes[dh.func_in_categories[func]] + for func in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'}) + for dtype in dh.category_to_dtypes[dh.op_in_categories[func]] for scalar_type in dh.dtypes_to_scalars[dtype] ] @@ -305,7 +208,7 @@ def test_operator_scalar_arg_return_promoted(func, dtype, scalar_type, """ See https://st.data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-hh.scalars """ - op = dh.binary_func_to_op[func] + op = dh.binary_op_to_symbol[func] if op == '@': pytest.skip("matmul (@) is not supported for hh.scalars") From 043726da7de9ad8e6631ba1c18ce1f1c6506faeb Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Oct 2021 17:02:11 +0100 Subject: [PATCH 15/41] Merge unary and binary tests --- array_api_tests/test_type_promotion.py | 279 +++++++++++++------------ 1 file changed, 143 insertions(+), 136 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 8a899e44..81b620b7 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -1,9 +1,8 @@ """ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ - -from itertools import product -from typing import Iterator, Literal +import re +from typing import Iterator import pytest from hypothesis import assume, given @@ -17,50 +16,21 @@ from .function_stubs import elementwise_functions from .pytest_helpers import nargs +# We apply filters to xps.arrays() when behaviour for certain array elements is +# erroneous or undefined. + + +filters = { + re.compile('bitwise_[a-z]+_shift'): lambda x: ah.all(x > 0), + re.compile('__[rl]shift__'): lambda x: ah.all(x > 0), +} -# Note: the boolean binary operators do not have reversed or in-place variants -def generate_params( - func_family: Literal['elementwise', 'operator'], - in_nargs: int, - out_category: Literal['bool', 'promoted'], -) -> Iterator: - if func_family == 'elementwise': - funcs = [ - f for f in elementwise_functions.__all__ - if nargs(f) == in_nargs and dh.func_out_categories[f] == out_category - ] - if in_nargs == 1: - for func in funcs: - in_category = dh.func_in_categories[func] - for in_dtype in dh.category_to_dtypes[in_category]: - yield pytest.param(func, in_dtype, id=f"{func}({in_dtype})") - else: - for func, ((d1, d2), d3) in product(funcs, dh.promotion_table.items()): - if all(d in dh.category_to_dtypes[dh.func_in_categories[func]] for d in (d1, d2)): - if out_category == 'bool': - yield pytest.param(func, (d1, d2), id=f"{func}({d1}, {d2})") - else: - yield pytest.param(func, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") - else: - if in_nargs == 1: - for func, op in dh.unary_op_to_symbol.items(): - if dh.op_out_categories[func] == out_category: - in_category = dh.op_in_categories[func] - for in_dtype in dh.category_to_dtypes[in_category]: - yield pytest.param(func, op, in_dtype, id=f"{func}({in_dtype})") - else: - for func, op in dh.binary_op_to_symbol.items(): - if func == "__matmul__": - continue - if dh.op_out_categories[func] == out_category: - in_category = dh.op_in_categories[func] - for ((d1, d2), d3) in dh.promotion_table.items(): - if all(d in dh.category_to_dtypes[in_category] for d in (d1, d2)): - if out_category == 'bool': - yield pytest.param(func, op, (d1, d2), id=f"{func}({d1}, {d2})") - else: - if d1 == d3: - yield pytest.param(func, op, ((d1, d2), d3), id=f"{func}({d1}, {d2}) -> {d3}") + +def get_filter(name): + for regex, cond in filters.items(): + if regex.match(name): + return cond + return lambda _: True def generate_func_params() -> Iterator: @@ -74,143 +44,180 @@ def generate_func_params() -> Iterator: for in_dtype in valid_in_dtypes: out_dtype = in_dtype if out_category == 'promoted' else xp.bool yield pytest.param( - func, (in_dtype,), out_dtype, id=f"{func_name}({in_dtype}) -> {out_dtype}" + func, + (in_dtype,), + out_dtype, + get_filter(func_name), + id=f'{func_name}({in_dtype}) -> {out_dtype}', ) elif ndtypes == 2: for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: - out_dtype = promoted_dtype if out_category == 'promoted' else xp.bool + out_dtype = ( + promoted_dtype if out_category == 'promoted' else xp.bool + ) yield pytest.param( - func, (in_dtype1, in_dtype2), out_dtype, id=f'{func_name}({in_dtype1}, {in_dtype2}) -> {out_dtype}' + func, + (in_dtype1, in_dtype2), + out_dtype, + get_filter(func_name), + id=f'{func_name}({in_dtype1}, {in_dtype2}) -> {out_dtype}', ) else: raise NotImplementedError() -@pytest.mark.parametrize('func, in_dtypes, out_dtype', generate_func_params()) +@pytest.mark.parametrize( + 'func, in_dtypes, out_dtype, arrays_filter', generate_func_params() +) @given(data=st.data()) -def test_func_returns_array_with_correct_dtype(func, in_dtypes, out_dtype, data): - arrays = [] - shapes = data.draw(hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes') - for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label='x{i}') - arrays.append(x) - out = func(*arrays) - assert out.dtype == out_dtype, f"{out.dtype=!s}, but should be {out_dtype}" - - -def generate_unary_op_params() -> Iterator: - for op, symbol in dh.unary_op_to_symbol.items(): - if op == '__abs__': - continue - in_category = dh.op_in_categories[op] - out_category = dh.op_out_categories[op] - valid_in_dtypes = dh.category_to_dtypes[in_category] - for in_dtype in valid_in_dtypes: - out_dtype = in_dtype if out_category == 'promoted' else xp.bool - yield pytest.param(symbol, in_dtype, out_dtype, id=f'{op}({in_dtype}) -> {out_dtype}') - - -@pytest.mark.parametrize('op_symbol, in_dtype, out_dtype', generate_unary_op_params()) -@given(data=st.data()) -def test_unary_operator_returns_array_with_correct_dtype(op_symbol, in_dtype, out_dtype, data): - x = data.draw(xps.arrays(dtype=in_dtype, shape=hh.shapes), label='x') - out = eval(f'{op_symbol}x', {"x": x}) - assert out.dtype == out_dtype, f"{out.dtype=!s}, but should be {out_dtype}" - - -def generate_abs_op_params() -> Iterator: - in_category = dh.op_in_categories['__abs__'] - out_category = dh.op_out_categories['__abs__'] - valid_in_dtypes = dh.category_to_dtypes[in_category] - for in_dtype in valid_in_dtypes: - out_dtype = in_dtype if out_category == 'promoted' else xp.bool - yield pytest.param(in_dtype, out_dtype, id=f'__abs__({in_dtype}) -> {out_dtype}') - - -@pytest.mark.parametrize('in_dtype, out_dtype', generate_abs_op_params()) -@given(data=st.data()) -def test_abs_operator_returns_array_with_correct_dtype(in_dtype, out_dtype, data): - x = data.draw(xps.arrays(dtype=in_dtype, shape=hh.shapes), label='x') - out = eval('abs(x)', {"x": x}) - assert out.dtype == out_dtype, f"{out.dtype=!s}, but should be {out_dtype}" - - -def generate_binary_op_params() -> Iterator: - for op, symbol in dh.binary_op_to_symbol.items(): - if op == '__matmul__' or 'shift' in op: +def test_func_returns_array_with_correct_dtype( + func, in_dtypes, out_dtype, arrays_filter, data +): + if len(in_dtypes) == 1: + x = data.draw( + xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(arrays_filter), + label='x', + ) + out = func(x) + else: + arrays = [] + shapes = data.draw( + hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes' + ) + for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): + x = data.draw( + xps.arrays(dtype=dtype, shape=shape).filter(arrays_filter), + label=f'x{i}', + ) + arrays.append(x) + out = func(*arrays) + assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' + + +def generate_op_params() -> Iterator: + op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} + for op, symbol in op_to_symbol.items(): + if op == '__matmul__': continue in_category = dh.op_in_categories[op] out_category = dh.op_out_categories[op] valid_in_dtypes = dh.category_to_dtypes[in_category] - for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): - if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: - out_dtype = promoted_dtype if out_category == 'promoted' else xp.bool + ndtypes = nargs(op) + if ndtypes == 1: + for in_dtype in valid_in_dtypes: + out_dtype = in_dtype if out_category == 'promoted' else xp.bool yield pytest.param( - symbol, - (in_dtype1, in_dtype2), + f'{symbol}x', + (in_dtype,), out_dtype, - id=f'{op}({in_dtype1}, {in_dtype2}) -> {out_dtype}' + get_filter(op), + id=f'{op}({in_dtype}) -> {out_dtype}', ) - - -@pytest.mark.parametrize('op_symbol, in_dtypes, out_dtype', generate_binary_op_params()) -@given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) -def test_binary_operator_returns_array_with_correct_dtype(op_symbol, in_dtypes, out_dtype, shapes, data): - x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1') - x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2') - out = eval(f'x1 {op_symbol} x2', {"x1": x1, "x2": x2}) - assert out.dtype == out_dtype, f"{out.dtype=!s}, but should be {out_dtype}" + else: + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: + out_dtype = ( + promoted_dtype if out_category == 'promoted' else xp.bool + ) + yield pytest.param( + f'x1 {symbol} x2', + (in_dtype1, in_dtype2), + out_dtype, + get_filter(op), + id=f'{op}({in_dtype1}, {in_dtype2}) -> {out_dtype}', + ) + # We generate params for abs seperately as it does not have an associated symbol + for in_dtype in dh.category_to_dtypes[dh.op_in_categories['__abs__']]: + yield pytest.param( + 'abs(x)', + (in_dtype,), + in_dtype, + get_filter('__abs__'), + id=f'__abs__({in_dtype}) -> {in_dtype}', + ) + + +@pytest.mark.parametrize( + 'expr, in_dtypes, out_dtype, arrays_filter', generate_op_params() +) +@given(data=st.data()) +def test_operator_returns_array_with_correct_dtype( + expr, in_dtypes, out_dtype, arrays_filter, data +): + if len(in_dtypes) == 1: + x = data.draw( + xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(arrays_filter), + label='x', + ) + out = eval(expr, {'x': x}) + else: + locals_ = {} + shapes = data.draw( + hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes' + ) + for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): + locals_[f'x{i}'] = data.draw( + xps.arrays(dtype=dtype, shape=shape).filter(arrays_filter), + label=f'x{i}', + ) + out = eval(expr, locals_) + assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' def generate_inplace_op_params() -> Iterator: for op, symbol in dh.binary_op_to_symbol.items(): - if op == '__matmul__' or 'shift' in op or '=' in symbol or '<' in symbol or '>' in symbol: + if op == '__matmul__' or dh.op_out_categories[op] == 'bool': continue in_category = dh.op_in_categories[op] - out_category = dh.op_out_categories[op] valid_in_dtypes = dh.category_to_dtypes[in_category] for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): - if in_dtype1 == promoted_dtype and in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: - out_dtype = promoted_dtype if out_category == 'promoted' else xp.bool + if ( + in_dtype1 == promoted_dtype + and in_dtype1 in valid_in_dtypes + and in_dtype2 in valid_in_dtypes + ): yield pytest.param( - f'{symbol}=', + f'x1 {symbol}= x2', (in_dtype1, in_dtype2), - out_dtype, - id=f'__i{op[2:]}({in_dtype1}, {in_dtype2}) -> {out_dtype}' + promoted_dtype, + id=f'__i{op[2:]}({in_dtype1}, {in_dtype2}) -> {in_dtype1}', ) -@pytest.mark.parametrize('op_symbol, in_dtypes, out_dtype', generate_inplace_op_params()) +@pytest.mark.parametrize('expr, in_dtypes, out_dtype', generate_inplace_op_params()) @given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) -def test_inplace_operator_returns_array_with_correct_dtype(op_symbol, in_dtypes, out_dtype, shapes, data): +def test_inplace_operator_returns_array_with_correct_dtype( + expr, in_dtypes, out_dtype, shapes, data +): assume(len(shapes[0]) >= len(shapes[1])) x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1') x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2') - locals_ = {"x1": x1, "x2": x2} - exec(f'x1 {op_symbol} x2', locals_) - x1 = locals_["x1"] - assert x1.dtype == out_dtype, f"{x1.dtype=!s}, but should be {out_dtype}" + locals_ = {'x1': x1, 'x2': x2} + exec(expr, locals_) + x1 = locals_['x1'] + assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' scalar_promotion_parametrize_inputs = [ - pytest.param(func, dtype, scalar_type, id=f"{func}-{dtype}-{scalar_type.__name__}") + pytest.param(func, dtype, scalar_type, id=f'{func}-{dtype}-{scalar_type.__name__}') for func in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'}) for dtype in dh.category_to_dtypes[dh.op_in_categories[func]] for scalar_type in dh.dtypes_to_scalars[dtype] ] -@pytest.mark.parametrize('func,dtype,scalar_type', - scalar_promotion_parametrize_inputs) + +@pytest.mark.parametrize('func,dtype,scalar_type', scalar_promotion_parametrize_inputs) @given(shape=hh.shapes, python_scalars=st.data(), data=st.data()) -def test_operator_scalar_arg_return_promoted(func, dtype, scalar_type, - shape, python_scalars, data): +def test_operator_scalar_arg_return_promoted( + func, dtype, scalar_type, shape, python_scalars, data +): """ - See https://st.data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-hh.scalars + See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars """ op = dh.binary_op_to_symbol[func] if op == '@': - pytest.skip("matmul (@) is not supported for hh.scalars") + pytest.skip('matmul (@) is not supported for hh.scalars') if dtype in dh.category_to_dtypes['integer']: s = python_scalars.draw(st.integers(*ah.dtype_ranges[dtype])) @@ -262,4 +269,4 @@ def test_operator_scalar_arg_return_promoted(func, dtype, scalar_type, if __name__ == '__main__': for (i, j), p in dh.promotion_table.items(): - print(f"({i}, {j}) -> {p}") + print(f'({i}, {j}) -> {p}') From 6efe39dbc5d402c8851de3e567250a81c319ddd6 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Oct 2021 17:03:55 +0100 Subject: [PATCH 16/41] Remove unused `pytest` import in `test_array_helpers.py` --- array_api_tests/meta_tests/test_array_helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_tests/meta_tests/test_array_helpers.py b/array_api_tests/meta_tests/test_array_helpers.py index 1a2ae832..6a6b4849 100644 --- a/array_api_tests/meta_tests/test_array_helpers.py +++ b/array_api_tests/meta_tests/test_array_helpers.py @@ -1,4 +1,3 @@ -import pytest from hypothesis import given, assume from hypothesis.strategies import integers From de5dd5ac16ce133608baead10d0ad69a70e40a33 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Oct 2021 17:32:08 +0100 Subject: [PATCH 17/41] Use a `defaultdict` for storing filters --- array_api_tests/test_type_promotion.py | 38 ++++++++++++++------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 81b620b7..fdf3bc6a 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -1,7 +1,7 @@ """ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ -import re +from collections import defaultdict from typing import Iterator import pytest @@ -16,21 +16,23 @@ from .function_stubs import elementwise_functions from .pytest_helpers import nargs -# We apply filters to xps.arrays() when behaviour for certain array elements is -# erroneous or undefined. - -filters = { - re.compile('bitwise_[a-z]+_shift'): lambda x: ah.all(x > 0), - re.compile('__[rl]shift__'): lambda x: ah.all(x > 0), -} +bitwise_shift_funcs = [ + 'bitwise_left_shift', + 'bitwise_right_shift', + '__lshift__', + '__rshift__', + '__ilshift__', + '__irshift__', +] -def get_filter(name): - for regex, cond in filters.items(): - if regex.match(name): - return cond - return lambda _: True +# We apply filters to xps.arrays() so we don't generate array elements that +# are erroneous or undefined for a function/operator. +filters = defaultdict( + lambda: lambda _: True, + {func: lambda x: ah.all(x > 0) for func in bitwise_shift_funcs}, +) def generate_func_params() -> Iterator: @@ -47,7 +49,7 @@ def generate_func_params() -> Iterator: func, (in_dtype,), out_dtype, - get_filter(func_name), + filters[func_name], id=f'{func_name}({in_dtype}) -> {out_dtype}', ) elif ndtypes == 2: @@ -60,7 +62,7 @@ def generate_func_params() -> Iterator: func, (in_dtype1, in_dtype2), out_dtype, - get_filter(func_name), + filters[func_name], id=f'{func_name}({in_dtype1}, {in_dtype2}) -> {out_dtype}', ) else: @@ -111,7 +113,7 @@ def generate_op_params() -> Iterator: f'{symbol}x', (in_dtype,), out_dtype, - get_filter(op), + filters[op], id=f'{op}({in_dtype}) -> {out_dtype}', ) else: @@ -124,7 +126,7 @@ def generate_op_params() -> Iterator: f'x1 {symbol} x2', (in_dtype1, in_dtype2), out_dtype, - get_filter(op), + filters[op], id=f'{op}({in_dtype1}, {in_dtype2}) -> {out_dtype}', ) # We generate params for abs seperately as it does not have an associated symbol @@ -133,7 +135,7 @@ def generate_op_params() -> Iterator: 'abs(x)', (in_dtype,), in_dtype, - get_filter('__abs__'), + filters['__abs__'], id=f'__abs__({in_dtype}) -> {in_dtype}', ) From 331e95cae706b96e46e4dc5aaa5850b8b914be93 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Oct 2021 17:44:01 +0100 Subject: [PATCH 18/41] Type hint parameter generators --- array_api_tests/test_type_promotion.py | 31 +++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index fdf3bc6a..1cf99f09 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -2,7 +2,7 @@ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ from collections import defaultdict -from typing import Iterator +from typing import Iterator, TypeVar, Tuple, Callable import pytest from hypothesis import assume, given @@ -27,6 +27,9 @@ ] +DT = TypeVar('DT') + + # We apply filters to xps.arrays() so we don't generate array elements that # are erroneous or undefined for a function/operator. filters = defaultdict( @@ -35,7 +38,7 @@ ) -def generate_func_params() -> Iterator: +def generate_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]]: for func_name in elementwise_functions.__all__: func = getattr(xp, func_name) in_category = dh.func_in_categories[func_name] @@ -69,16 +72,14 @@ def generate_func_params() -> Iterator: raise NotImplementedError() -@pytest.mark.parametrize( - 'func, in_dtypes, out_dtype, arrays_filter', generate_func_params() -) +@pytest.mark.parametrize('func, in_dtypes, out_dtype, x_filter', generate_func_params()) @given(data=st.data()) def test_func_returns_array_with_correct_dtype( - func, in_dtypes, out_dtype, arrays_filter, data + func, in_dtypes, out_dtype, x_filter, data ): if len(in_dtypes) == 1: x = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(arrays_filter), + xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x', ) out = func(x) @@ -89,7 +90,7 @@ def test_func_returns_array_with_correct_dtype( ) for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): x = data.draw( - xps.arrays(dtype=dtype, shape=shape).filter(arrays_filter), + xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}', ) arrays.append(x) @@ -97,7 +98,7 @@ def test_func_returns_array_with_correct_dtype( assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def generate_op_params() -> Iterator: +def generate_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} for op, symbol in op_to_symbol.items(): if op == '__matmul__': @@ -140,16 +141,14 @@ def generate_op_params() -> Iterator: ) -@pytest.mark.parametrize( - 'expr, in_dtypes, out_dtype, arrays_filter', generate_op_params() -) +@pytest.mark.parametrize('expr, in_dtypes, out_dtype, x_filter', generate_op_params()) @given(data=st.data()) def test_operator_returns_array_with_correct_dtype( - expr, in_dtypes, out_dtype, arrays_filter, data + expr, in_dtypes, out_dtype, x_filter, data ): if len(in_dtypes) == 1: x = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(arrays_filter), + xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x', ) out = eval(expr, {'x': x}) @@ -160,14 +159,14 @@ def test_operator_returns_array_with_correct_dtype( ) for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): locals_[f'x{i}'] = data.draw( - xps.arrays(dtype=dtype, shape=shape).filter(arrays_filter), + xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}', ) out = eval(expr, locals_) assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def generate_inplace_op_params() -> Iterator: +def generate_inplace_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__' or dh.op_out_categories[op] == 'bool': continue From 09d22813648a6c24c849946cb2aa2d63e1a7e040 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Oct 2021 18:04:03 +0100 Subject: [PATCH 19/41] Add filters to inplace operators test --- array_api_tests/test_type_promotion.py | 28 +++++++++++++++++--------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 1cf99f09..40940120 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -38,7 +38,7 @@ ) -def generate_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]]: +def gen_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]]: for func_name in elementwise_functions.__all__: func = getattr(xp, func_name) in_category = dh.func_in_categories[func_name] @@ -72,7 +72,7 @@ def generate_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Calla raise NotImplementedError() -@pytest.mark.parametrize('func, in_dtypes, out_dtype, x_filter', generate_func_params()) +@pytest.mark.parametrize('func, in_dtypes, out_dtype, x_filter', gen_func_params()) @given(data=st.data()) def test_func_returns_array_with_correct_dtype( func, in_dtypes, out_dtype, x_filter, data @@ -98,7 +98,7 @@ def test_func_returns_array_with_correct_dtype( assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def generate_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: +def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} for op, symbol in op_to_symbol.items(): if op == '__matmul__': @@ -141,7 +141,7 @@ def generate_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: ) -@pytest.mark.parametrize('expr, in_dtypes, out_dtype, x_filter', generate_op_params()) +@pytest.mark.parametrize('expr, in_dtypes, out_dtype, x_filter', gen_op_params()) @given(data=st.data()) def test_operator_returns_array_with_correct_dtype( expr, in_dtypes, out_dtype, x_filter, data @@ -166,7 +166,7 @@ def test_operator_returns_array_with_correct_dtype( assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def generate_inplace_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: +def gen_inplace_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__' or dh.op_out_categories[op] == 'bool': continue @@ -178,22 +178,30 @@ def generate_inplace_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Call and in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes ): + iop = f'__i{op[2:]}' yield pytest.param( f'x1 {symbol}= x2', (in_dtype1, in_dtype2), promoted_dtype, - id=f'__i{op[2:]}({in_dtype1}, {in_dtype2}) -> {in_dtype1}', + filters[iop], + id=f'{iop}({in_dtype1}, {in_dtype2}) -> {in_dtype1}', ) -@pytest.mark.parametrize('expr, in_dtypes, out_dtype', generate_inplace_op_params()) +@pytest.mark.parametrize('expr, in_dtypes, out_dtype, x_filter', gen_inplace_params()) @given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) def test_inplace_operator_returns_array_with_correct_dtype( - expr, in_dtypes, out_dtype, shapes, data + expr, in_dtypes, out_dtype, x_filter, shapes, data ): assume(len(shapes[0]) >= len(shapes[1])) - x1 = data.draw(xps.arrays(dtype=in_dtypes[0], shape=shapes[0]), label='x1') - x2 = data.draw(xps.arrays(dtype=in_dtypes[1], shape=shapes[1]), label='x2') + x1 = data.draw( + xps.arrays(dtype=in_dtypes[0], shape=shapes[0]).filter(x_filter), + label='x1', + ) + x2 = data.draw( + xps.arrays(dtype=in_dtypes[1], shape=shapes[1]).filter(x_filter), + label='x2', + ) locals_ = {'x1': x1, 'x2': x2} exec(expr, locals_) x1 = locals_['x1'] From 5a0607872e3d64be7f6b538290e836487876f4dc Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Oct 2021 18:09:46 +0100 Subject: [PATCH 20/41] Make `op_to_func` public, use `promotion_table` in linalg tests --- array_api_tests/dtype_helpers.py | 5 +++-- array_api_tests/test_linalg.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 287a99ba..a02cfa23 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -13,6 +13,7 @@ 'op_out_categories', 'binary_op_to_symbol', 'unary_op_to_symbol', + 'op_to_func', ] @@ -254,7 +255,7 @@ } -_operator_to_elementwise = { +op_to_func = { '__abs__': 'abs', '__add__': 'add', '__and__': 'bitwise_and', @@ -283,6 +284,6 @@ op_in_categories = {} op_out_categories = {} -for op_func, elwise_func in _operator_to_elementwise.items(): +for op_func, elwise_func in op_to_func.items(): op_in_categories[op_func] = func_in_categories[elwise_func] op_out_categories[op_func] = func_out_categories[elwise_func] diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index cbf22429..fb8149bd 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -17,13 +17,14 @@ from hypothesis.strategies import booleans, composite, none, integers, shared from .array_helpers import (assert_exactly_equal, ndindex, asarray, - numeric_dtype_objects, promote_dtypes) + numeric_dtype_objects) from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, square_matrix_shapes, symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, invertible_matrices, two_mutual_arrays, mutually_promotable_dtypes) from .pytest_helpers import raises +from .dtype_helpers import promotion_table from .test_broadcasting import broadcast_shapes @@ -132,7 +133,7 @@ def test_cross(x1_x2_kw): res = linalg.cross(x1, x2, **kw) - assert res.dtype == promote_dtypes(x1, x2), "cross() did not return the correct dtype" + assert res.dtype == promotion_table[x1, x2], "cross() did not return the correct dtype" assert res.shape == shape, "cross() did not return the correct shape" # cross is too different from other functions to use _test_stacks, and it @@ -268,7 +269,7 @@ def test_matmul(x1, x2): else: res = linalg.matmul(x1, x2) - assert res.dtype == promote_dtypes(x1, x2), "matmul() did not return the correct dtype" + assert res.dtype == promotion_table[x1, x2], "matmul() did not return the correct dtype" if len(x1.shape) == len(x2.shape) == 1: assert res.shape == () From 357aa9493247d79c4a3df0b34449edb874b424ad Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Oct 2021 21:54:26 +0100 Subject: [PATCH 21/41] Remove unnecessary TODO comment --- array_api_tests/dtype_helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index a02cfa23..b8d375d9 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -86,7 +86,6 @@ promotion_table = { (xp.bool, xp.bool): xp.bool, **_numeric_promotions, - # TODO: dont unpack pairs of the same dtype **{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()}, } From 26cf341ea51565d74b224a4c9c5c85bf85a7a89b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Oct 2021 12:58:40 +0100 Subject: [PATCH 22/41] Move `ah*_objects` and `ah.dtype_ranges` into `dtype_helpers` --- array_api_tests/array_helpers.py | 25 +-------- array_api_tests/dtype_helpers.py | 21 ++++++++ array_api_tests/hypothesis_helpers.py | 36 ++++++------- .../meta_tests/test_hypothesis_helpers.py | 3 +- array_api_tests/test_creation_functions.py | 3 +- array_api_tests/test_elementwise_functions.py | 51 ++++++++++--------- array_api_tests/test_linalg.py | 13 +++-- array_api_tests/test_type_promotion.py | 2 +- 8 files changed, 76 insertions(+), 78 deletions(-) diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index 9d00b36b..60db4725 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -5,10 +5,7 @@ zeros, ones, full, bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, nan, inf, pi, remainder, divide, isinf, - negative, _integer_dtypes, _floating_dtypes, - _numeric_dtypes, _boolean_dtypes, _dtypes, - asarray) -from . import _array_module + negative, asarray) # These are exported here so that they can be included in the special cases # tests from this file. @@ -27,7 +24,7 @@ 'assert_positive_mathematical_sign', 'negative_mathematical_sign', 'assert_negative_mathematical_sign', 'same_sign', 'assert_same_sign', 'ndindex', 'float64', - 'asarray', 'is_integer_dtype', 'is_float_dtype', 'dtype_ranges', + 'asarray', 'is_integer_dtype', 'is_float_dtype', 'full', 'true', 'false', 'isnan'] def zero(shape, dtype): @@ -312,13 +309,6 @@ def same_sign(x, y): def assert_same_sign(x, y): assert all(same_sign(x, y)), "The input arrays do not have the same sign" -integer_dtype_objects = [getattr(_array_module, t) for t in _integer_dtypes] -floating_dtype_objects = [getattr(_array_module, t) for t in _floating_dtypes] -numeric_dtype_objects = [getattr(_array_module, t) for t in _numeric_dtypes] -boolean_dtype_objects = [getattr(_array_module, t) for t in _boolean_dtypes] -integer_or_boolean_dtype_objects = integer_dtype_objects + boolean_dtype_objects -dtype_objects = [getattr(_array_module, t) for t in _dtypes] - def is_integer_dtype(dtype): if dtype is None: return False @@ -332,17 +322,6 @@ def is_float_dtype(dtype): # spec, like np.float16 return dtype in [float32, float64] -dtype_ranges = { - int8: [-128, +127], - int16: [-32_768, +32_767], - int32: [-2_147_483_648, +2_147_483_647], - int64: [-9_223_372_036_854_775_808, +9_223_372_036_854_775_807], - uint8: [0, +255], - uint16: [0, +65_535], - uint32: [0, +4_294_967_295], - uint64: [0, +18_446_744_073_709_551_615], -} - def int_to_dtype(x, n, signed): """ Convert the Python integer x into an n bit signed or unsigned number. diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index b8d375d9..2b8c5c11 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -2,7 +2,15 @@ __all__ = [ + 'int_dtypes', + 'uint_dtypes', + 'all_int_dtypes', + 'float_dtypes', + 'numeric_dtypes', + 'all_dtypes', + 'bool_and_all_int_dtypes', 'dtypes_to_scalars', + 'dtype_ranges', 'category_to_dtypes', 'promotion_table', 'dtype_nbits', @@ -23,6 +31,7 @@ float_dtypes = (xp.float32, xp.float64) numeric_dtypes = all_int_dtypes + float_dtypes all_dtypes = (xp.bool,) + numeric_dtypes +bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes dtypes_to_scalars = { @@ -32,6 +41,18 @@ } +dtype_ranges = { + xp.int8: [-128, +127], + xp.int16: [-32_768, +32_767], + xp.int32: [-2_147_483_648, +2_147_483_647], + xp.int64: [-9_223_372_036_854_775_808, +9_223_372_036_854_775_807], + xp.uint8: [0, +255], + xp.uint16: [0, +65_535], + xp.uint32: [0, +4_294_967_295], + xp.uint64: [0, +18_446_744_073_709_551_615], +} + + category_to_dtypes = { 'any': all_dtypes, 'boolean': (xp.bool,), diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index c853a871..108d40ea 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -11,12 +11,8 @@ from hypothesis.strategies._internal.strategies import SearchStrategy from .pytest_helpers import nargs -from .array_helpers import (dtype_ranges, integer_dtype_objects, - floating_dtype_objects, numeric_dtype_objects, - boolean_dtype_objects, - integer_or_boolean_dtype_objects, dtype_objects, - ndindex) -from .dtype_helpers import promotion_table +from .array_helpers import ndindex +from . import dtype_helpers as dh from ._array_module import (full, float32, float64, bool as bool_dtype, _UndefinedStub, eye, broadcast_to) from . import _array_module as xp @@ -32,12 +28,12 @@ # places in the tests. FILTER_UNDEFINED_DTYPES = True -integer_dtypes = sampled_from(integer_dtype_objects) -floating_dtypes = sampled_from(floating_dtype_objects) -numeric_dtypes = sampled_from(numeric_dtype_objects) -integer_or_boolean_dtypes = sampled_from(integer_or_boolean_dtype_objects) -boolean_dtypes = sampled_from(boolean_dtype_objects) -dtypes = sampled_from(dtype_objects) +integer_dtypes = sampled_from(dh.all_int_dtypes) +floating_dtypes = sampled_from(dh.float_dtypes) +numeric_dtypes = sampled_from(dh.numeric_dtypes) +integer_or_boolean_dtypes = sampled_from(dh.bool_and_all_int_dtypes) +boolean_dtypes = just(xp.bool) +dtypes = sampled_from(dh.all_dtypes) if FILTER_UNDEFINED_DTYPES: integer_dtypes = integer_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) @@ -52,7 +48,7 @@ shared_floating_dtypes = shared(floating_dtypes, key="dtype") -sorted_table = sorted(promotion_table) +sorted_table = sorted(dh.promotion_table) sorted_table = sorted( sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij) ) @@ -62,7 +58,7 @@ and not isinstance(j, _UndefinedStub)] -def mutually_promotable_dtypes(dtype_objects=dtype_objects): +def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes): # sort for shrinking (sampled_from shrinks to the earlier elements in the # list). Give pairs of the same dtypes first, then smaller dtypes, # preferring float, then int, then unsigned int. Note, this might not @@ -71,7 +67,7 @@ def mutually_promotable_dtypes(dtype_objects=dtype_objects): # pairs (XXX: Can we redesign the strategies so that they can prefer # shrinking dtypes over values?) return sampled_from( - [(i, j) for i, j in sorted_table if i in dtype_objects and j in dtype_objects] + [(i, j) for i, j in sorted_table if i in dtype_objs and j in dtype_objs] ) # shared() allows us to draw either the function or the function name and they @@ -195,8 +191,8 @@ def scalars(draw, dtypes, finite=False): dtypes should be one of the shared_* dtypes strategies. """ dtype = draw(dtypes) - if dtype in dtype_ranges: - m, M = dtype_ranges[dtype] + if dtype in dh.dtype_ranges: + m, M = dh.dtype_ranges[dtype] return draw(integers(m, M)) elif dtype == bool_dtype: return draw(booleans()) @@ -228,7 +224,7 @@ def integer_indices(draw, sizes): # Return either a Python integer or a 0-D array with some integer dtype idx = draw(python_integer_indices(sizes)) dtype = draw(integer_dtypes) - m, M = dtype_ranges[dtype] + m, M = dh.dtype_ranges[dtype] if m <= idx <= M: return draw(one_of(just(idx), just(full((), idx, dtype=dtype)))) @@ -297,8 +293,8 @@ def multiaxis_indices(draw, shapes): return tuple(res) -def two_mutual_arrays(dtype_objects=dtype_objects): - mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objects)) +def two_mutual_arrays(dtype_objs=dh.all_dtypes): + mutual_dtypes = shared(mutually_promotable_dtypes(dtype_objs)) mutual_shapes = shared(two_mutually_broadcastable_shapes) arrays1 = xps.arrays( dtype=mutual_dtypes.map(lambda pair: pair[0]), diff --git a/array_api_tests/meta_tests/test_hypothesis_helpers.py b/array_api_tests/meta_tests/test_hypothesis_helpers.py index 6af43a26..93a63f8e 100644 --- a/array_api_tests/meta_tests/test_hypothesis_helpers.py +++ b/array_api_tests/meta_tests/test_hypothesis_helpers.py @@ -6,11 +6,12 @@ from .. import _array_module as xp from .._array_module import _UndefinedStub from .. import array_helpers as ah +from .. import dtype_helpers as dh from .. import hypothesis_helpers as hh from ..test_broadcasting import broadcast_shapes from ..test_elementwise_functions import sanity_check -UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in ah.dtype_objects) +UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes) pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")] @given(hh.mutually_promotable_dtypes([xp.float32, xp.float64])) diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 91b93882..690e2182 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -4,11 +4,12 @@ full_like, equal, all, linspace, ones, ones_like, zeros, zeros_like, isnan) from . import _array_module as xp -from .array_helpers import (is_integer_dtype, dtype_ranges, +from .array_helpers import (is_integer_dtype, assert_exactly_equal, isintegral, is_float_dtype) from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE, shapes, sizes, sqrt_sizes, shared_dtypes, scalars, kwargs) +from . dtype_helpers import dtype_ranges from . import xps from hypothesis import assume, given diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index 443ad28b..e4e2d5f8 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -17,6 +17,7 @@ from . import _array_module as xp from . import array_helpers as ah from . import hypothesis_helpers as hh +from . import dtype_helpers as dh from . import xps from .dtype_helpers import dtype_nbits, dtype_signed, promotion_table # We might as well use this implementation rather than requiring @@ -29,11 +30,11 @@ integer_or_boolean_scalars = hh.array_scalars(hh.integer_or_boolean_dtypes) boolean_scalars = hh.array_scalars(hh.boolean_dtypes) -two_integer_dtypes = hh.mutually_promotable_dtypes(hh.integer_dtype_objects) -two_floating_dtypes = hh.mutually_promotable_dtypes(hh.floating_dtype_objects) -two_numeric_dtypes = hh.mutually_promotable_dtypes(hh.numeric_dtype_objects) -two_integer_or_boolean_dtypes = hh.mutually_promotable_dtypes(hh.integer_or_boolean_dtype_objects) -two_boolean_dtypes = hh.mutually_promotable_dtypes(hh.boolean_dtype_objects) +two_integer_dtypes = hh.mutually_promotable_dtypes(dh.all_int_dtypes) +two_floating_dtypes = hh.mutually_promotable_dtypes(dh.float_dtypes) +two_numeric_dtypes = hh.mutually_promotable_dtypes(dh.numeric_dtypes) +two_integer_or_boolean_dtypes = hh.mutually_promotable_dtypes(dh.bool_and_all_int_dtypes) +two_boolean_dtypes = hh.mutually_promotable_dtypes((xp.bool,)) two_any_dtypes = hh.mutually_promotable_dtypes() @st.composite @@ -52,7 +53,7 @@ def sanity_check(x1, x2): @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes)) def test_abs(x): if ah.is_integer_dtype(x.dtype): - minval = ah.dtype_ranges[x.dtype][0] + minval = dh.dtype_ranges[x.dtype][0] if minval < 0: # abs of the smallest representable negative integer is not defined mask = xp.not_equal(x, ah.full(x.shape, minval, dtype=x.dtype)) @@ -92,7 +93,7 @@ def test_acosh(x): # to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_add(x1, x2): sanity_check(x1, x2) a = xp.add(x1, x2) @@ -134,7 +135,7 @@ def test_atan(x): # mapped to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(hh.floating_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_atan2(x1, x2): sanity_check(x1, x2) a = xp.atan2(x1, x2) @@ -181,7 +182,7 @@ def test_atanh(x): # mapped to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes)) def test_bitwise_and(x1, x2): sanity_check(x1, x2) out = xp.bitwise_and(x1, x2) @@ -208,7 +209,7 @@ def test_bitwise_and(x1, x2): vals_and = ah.int_to_dtype(vals_and, dtype_nbits[out.dtype], dtype_signed[out.dtype]) assert vals_and == res -@given(*hh.two_mutual_arrays(ah.integer_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.all_int_dtypes)) def test_bitwise_left_shift(x1, x2): sanity_check(x1, x2) assume(not ah.any(ah.isnegative(x2))) @@ -247,7 +248,7 @@ def test_bitwise_invert(x): val_invert = ah.int_to_dtype(val_invert, dtype_nbits[out.dtype], dtype_signed[out.dtype]) assert val_invert == res -@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes)) def test_bitwise_or(x1, x2): sanity_check(x1, x2) out = xp.bitwise_or(x1, x2) @@ -274,7 +275,7 @@ def test_bitwise_or(x1, x2): vals_or = ah.int_to_dtype(vals_or, dtype_nbits[out.dtype], dtype_signed[out.dtype]) assert vals_or == res -@given(*hh.two_mutual_arrays(ah.integer_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.all_int_dtypes)) def test_bitwise_right_shift(x1, x2): sanity_check(x1, x2) assume(not ah.any(ah.isnegative(x2))) @@ -295,7 +296,7 @@ def test_bitwise_right_shift(x1, x2): vals_shift = ah.int_to_dtype(vals_shift, dtype_nbits[out.dtype], dtype_signed[out.dtype]) assert vals_shift == res -@given(*hh.two_mutual_arrays(ah.integer_or_boolean_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes)) def test_bitwise_xor(x1, x2): sanity_check(x1, x2) out = xp.bitwise_xor(x1, x2) @@ -354,7 +355,7 @@ def test_cosh(x): # mapped to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(hh.floating_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_divide(x1, x2): sanity_check(x1, x2) xp.divide(x1, x2) @@ -447,7 +448,7 @@ def test_floor(x): integers = ah.isintegral(x) ah.assert_exactly_equal(a[integers], x[integers]) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_floor_divide(x1, x2): sanity_check(x1, x2) if ah.is_integer_dtype(x1.dtype): @@ -471,7 +472,7 @@ def test_floor_divide(x1, x2): # TODO: Test the exact output for floor_divide. -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_greater(x1, x2): sanity_check(x1, x2) a = xp.greater(x1, x2) @@ -500,7 +501,7 @@ def test_greater(x1, x2): assert aidx.shape == x1idx.shape == x2idx.shape assert bool(aidx) == (scalar_func(x1idx) > scalar_func(x2idx)) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_greater_equal(x1, x2): sanity_check(x1, x2) a = xp.greater_equal(x1, x2) @@ -575,7 +576,7 @@ def test_isnan(x): s = float(x[idx]) assert bool(a[idx]) == math.isnan(s) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_less(x1, x2): sanity_check(x1, x2) a = ah.less(x1, x2) @@ -604,7 +605,7 @@ def test_less(x1, x2): assert aidx.shape == x1idx.shape == x2idx.shape assert bool(aidx) == (scalar_func(x1idx) < scalar_func(x2idx)) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_less_equal(x1, x2): sanity_check(x1, x2) a = ah.less_equal(x1, x2) @@ -677,7 +678,7 @@ def test_log10(x): # mapped to nan, which is already tested in the special cases. ah.assert_exactly_equal(domain, codomain) -@given(*hh.two_mutual_arrays(hh.floating_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_logaddexp(x1, x2): sanity_check(x1, x2) xp.logaddexp(x1, x2) @@ -731,7 +732,7 @@ def test_logical_xor(x1, x2): for idx in ah.ndindex(shape): assert a[idx] == (bool(_x1[idx]) ^ bool(_x2[idx])) -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_multiply(x1, x2): sanity_check(x1, x2) a = xp.multiply(x1, x2) @@ -749,7 +750,7 @@ def test_negative(x): mask = ah.isfinite(x) if ah.is_integer_dtype(x.dtype): - minval = ah.dtype_ranges[x.dtype][0] + minval = dh.dtype_ranges[x.dtype][0] if minval < 0: # negative of the smallest representable negative integer is not defined mask = xp.not_equal(x, ah.full(x.shape, minval, dtype=x.dtype)) @@ -796,7 +797,7 @@ def test_positive(x): # Positive does nothing ah.assert_exactly_equal(out, x) -@given(*hh.two_mutual_arrays(hh.floating_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.float_dtypes)) def test_pow(x1, x2): sanity_check(x1, x2) xp.pow(x1, x2) @@ -806,7 +807,7 @@ def test_pow(x1, x2): # numbers. We could test that this does implement IEEE 754 pow, but we # don't yet have those sorts in general for this module. -@given(*hh.two_mutual_arrays(hh.numeric_dtype_objects)) +@given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_remainder(x1, x2): assume(len(x1.shape) <= len(x2.shape)) # TODO: rework same sign testing below to remove this sanity_check(x1, x2) @@ -886,7 +887,7 @@ def test_trunc(x): out = xp.trunc(x) assert out.dtype == x.dtype, f"{x.dtype=!s} but {out.dtype=!s}" assert out.shape == x.shape, f"{x.shape=} but {out.shape=}" - if x.dtype in hh.integer_dtype_objects: + if x.dtype in dh.all_int_dtypes: assert ah.all(ah.equal(x, out)), f"{x=!s} but {out=!s}" else: finite = ah.isfinite(x) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index fb8149bd..8473d9c2 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -16,15 +16,14 @@ from hypothesis import assume, given from hypothesis.strategies import booleans, composite, none, integers, shared -from .array_helpers import (assert_exactly_equal, ndindex, asarray, - numeric_dtype_objects) +from .array_helpers import assert_exactly_equal, ndindex, asarray from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, square_matrix_shapes, symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, invertible_matrices, two_mutual_arrays, mutually_promotable_dtypes) from .pytest_helpers import raises -from .dtype_helpers import promotion_table +from . import dtype_helpers as dh from .test_broadcasting import broadcast_shapes @@ -92,7 +91,7 @@ def test_cholesky(x, kw): @composite -def cross_args(draw, dtype_objects=numeric_dtype_objects): +def cross_args(draw, dtype_objects=dh.numeric_dtypes): """ cross() requires two arrays with a size 3 in the 'axis' dimension @@ -133,7 +132,7 @@ def test_cross(x1_x2_kw): res = linalg.cross(x1, x2, **kw) - assert res.dtype == promotion_table[x1, x2], "cross() did not return the correct dtype" + assert res.dtype == dh.promotion_table[x1, x2], "cross() did not return the correct dtype" assert res.shape == shape, "cross() did not return the correct shape" # cross is too different from other functions to use _test_stacks, and it @@ -252,7 +251,7 @@ def test_inv(x): # TODO: Test that the result is actually the inverse @given( - *two_mutual_arrays(numeric_dtype_objects) + *two_mutual_arrays(dh.numeric_dtypes) ) def test_matmul(x1, x2): # TODO: Make this also test the @ operator @@ -269,7 +268,7 @@ def test_matmul(x1, x2): else: res = linalg.matmul(x1, x2) - assert res.dtype == promotion_table[x1, x2], "matmul() did not return the correct dtype" + assert res.dtype == dh.promotion_table[x1, x2], "matmul() did not return the correct dtype" if len(x1.shape) == len(x2.shape) == 1: assert res.shape == () diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 40940120..e23fdc7d 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -229,7 +229,7 @@ def test_operator_scalar_arg_return_promoted( pytest.skip('matmul (@) is not supported for hh.scalars') if dtype in dh.category_to_dtypes['integer']: - s = python_scalars.draw(st.integers(*ah.dtype_ranges[dtype])) + s = python_scalars.draw(st.integers(*dh.dtype_ranges[dtype])) else: s = python_scalars.draw(st.from_type(scalar_type)) scalar_as_array = ah.asarray(s, dtype=dtype) From a0e410dc15515e1f6ba92986ed4b09ef935e5235 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Oct 2021 13:34:11 +0100 Subject: [PATCH 23/41] Move `ah.is_*_dtype` methods into `dtype_helpers` --- array_api_tests/array_helpers.py | 22 ++---- array_api_tests/dtype_helpers.py | 17 +++++ array_api_tests/test_creation_functions.py | 37 +++++----- array_api_tests/test_elementwise_functions.py | 71 +++++++++---------- 4 files changed, 74 insertions(+), 73 deletions(-) diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py index 60db4725..398f1994 100644 --- a/array_api_tests/array_helpers.py +++ b/array_api_tests/array_helpers.py @@ -6,10 +6,10 @@ int64, uint8, uint16, uint32, uint64, float32, float64, nan, inf, pi, remainder, divide, isinf, negative, asarray) - # These are exported here so that they can be included in the special cases # tests from this file. from ._array_module import logical_not, subtract, floor, ceil, where +from . import dtype_helpers as dh __all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less', @@ -24,8 +24,7 @@ 'assert_positive_mathematical_sign', 'negative_mathematical_sign', 'assert_negative_mathematical_sign', 'same_sign', 'assert_same_sign', 'ndindex', 'float64', - 'asarray', 'is_integer_dtype', 'is_float_dtype', - 'full', 'true', 'false', 'isnan'] + 'asarray', 'full', 'true', 'false', 'isnan'] def zero(shape, dtype): """ @@ -109,7 +108,7 @@ def isnegzero(x): # TODO: If copysign or signbit are added to the spec, use those instead. shape = x.shape dtype = x.dtype - if is_integer_dtype(dtype): + if dh.is_int_dtype(dtype): return false(shape) return equal(divide(one(shape, dtype), x), -infinity(shape, dtype)) @@ -120,7 +119,7 @@ def isposzero(x): # TODO: If copysign or signbit are added to the spec, use those instead. shape = x.shape dtype = x.dtype - if is_integer_dtype(dtype): + if dh.is_int_dtype(dtype): return true(shape) return equal(divide(one(shape, dtype), x), infinity(shape, dtype)) @@ -309,19 +308,6 @@ def same_sign(x, y): def assert_same_sign(x, y): assert all(same_sign(x, y)), "The input arrays do not have the same sign" -def is_integer_dtype(dtype): - if dtype is None: - return False - return dtype in [int8, int16, int32, int64, uint8, uint16, uint32, uint64] - -def is_float_dtype(dtype): - if dtype is None: - # numpy.dtype('float64') == None gives True - return False - # TODO: Return True even for floating point dtypes that aren't part of the - # spec, like np.float16 - return dtype in [float32, float64] - def int_to_dtype(x, n, signed): """ Convert the Python integer x into an n bit signed or unsigned number. diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 2b8c5c11..05c40a7b 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -10,6 +10,8 @@ 'all_dtypes', 'bool_and_all_int_dtypes', 'dtypes_to_scalars', + 'is_int_dtype', + 'is_float_dtype', 'dtype_ranges', 'category_to_dtypes', 'promotion_table', @@ -41,6 +43,21 @@ } +def is_int_dtype(dtype): + return dtype in all_int_dtypes + + +def is_float_dtype(dtype): + # None equals NumPy's xp.float64 object, so we specifically check it here. + # xp.float64 is in fact an alias of np.dtype('float64'), and its equality + # with None is meant to be deprecated at some point. + # See https://github.com/numpy/numpy/issues/18434 + if dtype is None: + return False + # TODO: Return True for float dtypes that aren't part of the spec e.g. np.float16 + return dtype in float_dtypes + + dtype_ranges = { xp.int8: [-128, +127], xp.int16: [-32_768, +32_767], diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 690e2182..302bef49 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -4,12 +4,11 @@ full_like, equal, all, linspace, ones, ones_like, zeros, zeros_like, isnan) from . import _array_module as xp -from .array_helpers import (is_integer_dtype, - assert_exactly_equal, isintegral, is_float_dtype) +from .array_helpers import assert_exactly_equal, isintegral from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE, shapes, sizes, sqrt_sizes, shared_dtypes, scalars, kwargs) -from . dtype_helpers import dtype_ranges +from . import dtype_helpers as dh from . import xps from hypothesis import assume, given @@ -25,8 +24,8 @@ and (abs(x) > 0.01 if isinstance(x, float) else True)), one_of(none(), numeric_dtypes)) def test_arange(start, stop, step, dtype): - if dtype in dtype_ranges: - m, M = dtype_ranges[dtype] + if dtype in dh.dtype_ranges: + m, M = dh.dtype_ranges[dtype] if (not (m <= start <= M) or isinstance(stop, int) and not (m <= stop <= M) or isinstance(step, int) and not (m <= step <= M)): @@ -34,7 +33,7 @@ def test_arange(start, stop, step, dtype): kwargs = {} if dtype is None else {'dtype': dtype} - all_int = (is_integer_dtype(dtype) + all_int = (dh.is_int_dtype(dtype) and isinstance(start, int) and (stop is None or isinstance(stop, int)) and (step is None or isinstance(step, int))) @@ -76,7 +75,7 @@ def test_empty(shape, kw): out = empty(shape, **kw) dtype = kw.get("dtype", None) or xp.float64 if kw.get("dtype", None) is None: - assert is_float_dtype(out.dtype), f"empty() returned an array with dtype {out.dtype}, but should be the default float dtype" + assert dh.is_float_dtype(out.dtype), f"empty() returned an array with dtype {out.dtype}, but should be the default float dtype" else: assert out.dtype == dtype, f"{dtype=!s}, but empty() returned an array with dtype {out.dtype}" if isinstance(shape, int): @@ -111,7 +110,7 @@ def test_eye(n_rows, n_cols, k, dtype): else: a = eye(n_rows, n_cols, **kwargs) if dtype is None: - assert is_float_dtype(a.dtype), "eye() should return an array with the default floating point dtype" + assert dh.is_float_dtype(a.dtype), "eye() should return an array with the default floating point dtype" else: assert a.dtype == dtype, "eye() did not produce the correct dtype" @@ -153,7 +152,7 @@ def test_full(shape, fill_value, kw): dtype = xp.float64 if kw.get("dtype", None) is None: if dtype == xp.float64: - assert is_float_dtype(out.dtype), f"full() returned an array with dtype {out.dtype}, but should be the default float dtype" + assert dh.is_float_dtype(out.dtype), f"full() returned an array with dtype {out.dtype}, but should be the default float dtype" elif dtype == xp.int64: assert out.dtype == xp.int32 or out.dtype == xp.int64, f"full() returned an array with dtype {out.dtype}, but should be the default integer dtype" else: @@ -161,7 +160,7 @@ def test_full(shape, fill_value, kw): else: assert out.dtype == dtype assert out.shape == shape, f"{shape=}, but full() returned an array with shape {out.shape}" - if is_float_dtype(out.dtype) and math.isnan(fill_value): + if dh.is_float_dtype(out.dtype) and math.isnan(fill_value): assert all(isnan(out)), "full() array did not equal the fill value" else: assert all(equal(out, asarray(fill_value, dtype=dtype))), "full() array did not equal the fill value" @@ -187,7 +186,7 @@ def test_full_like(x, fill_value, kw): else: assert out.dtype == dtype, f"{dtype=!s}, but full_like() returned an array with dtype {out.dtype}" assert out.shape == x.shape, "{x.shape=}, but full_like() returned an array with shape {out.shape}" - if is_float_dtype(dtype) and math.isnan(fill_value): + if dh.is_float_dtype(dtype) and math.isnan(fill_value): assert all(isnan(out)), "full_like() array did not equal the fill value" else: assert all(equal(out, asarray(fill_value, dtype=dtype))), "full_like() array did not equal the fill value" @@ -201,7 +200,7 @@ def test_full_like(x, fill_value, kw): def test_linspace(start, stop, num, dtype, endpoint): # Skip on int start or stop that cannot be exactly represented as a float, # since we do not have good approx_equal helpers yet. - if ((dtype is None or is_float_dtype(dtype)) + if ((dtype is None or dh.is_float_dtype(dtype)) and ((isinstance(start, int) and not isintegral(asarray(start, dtype=dtype))) or (isinstance(stop, int) and not isintegral(asarray(stop, dtype=dtype))))): assume(False) @@ -211,7 +210,7 @@ def test_linspace(start, stop, num, dtype, endpoint): a = linspace(start, stop, num, **kwargs) if dtype is None: - assert is_float_dtype(a.dtype), "linspace() should return an array with the default floating point dtype" + assert dh.is_float_dtype(a.dtype), "linspace() should return an array with the default floating point dtype" else: assert a.dtype == dtype, "linspace() did not produce the correct dtype" @@ -238,9 +237,9 @@ def test_linspace(start, stop, num, dtype, endpoint): def make_one(dtype): - if kwargs is None or is_float_dtype(dtype): + if kwargs is None or dh.is_float_dtype(dtype): return 1.0 - elif is_integer_dtype(dtype): + elif dh.is_int_dtype(dtype): return 1 else: return True @@ -251,7 +250,7 @@ def test_ones(shape, kw): out = ones(shape, **kw) dtype = kw.get("dtype", None) or xp.float64 if kw.get("dtype", None) is None: - assert is_float_dtype(out.dtype), f"ones() returned an array with dtype {out.dtype}, but should be the default float dtype" + assert dh.is_float_dtype(out.dtype), f"ones() returned an array with dtype {out.dtype}, but should be the default float dtype" else: assert out.dtype == dtype, f"{dtype=!s}, but ones() returned an array with dtype {out.dtype}" assert out.shape == shape, f"{shape=}, but empty() returned an array with shape {out.shape}" @@ -274,9 +273,9 @@ def test_ones_like(x, kw): def make_zero(dtype): - if is_float_dtype(dtype): + if dh.is_float_dtype(dtype): return 0.0 - elif is_integer_dtype(dtype): + elif dh.is_int_dtype(dtype): return 0 else: return False @@ -287,7 +286,7 @@ def test_zeros(shape, kw): out = zeros(shape, **kw) dtype = kw.get("dtype", None) or xp.float64 if kw.get("dtype", None) is None: - assert is_float_dtype(out.dtype), "zeros() returned an array with dtype {out.dtype}, but should be the default float dtype" + assert dh.is_float_dtype(out.dtype), "zeros() returned an array with dtype {out.dtype}, but should be the default float dtype" else: assert out.dtype == dtype, f"{dtype=!s}, but zeros() returned an array with dtype {out.dtype}" assert out.shape == shape, "zeros() produced an array with incorrect shape" diff --git a/array_api_tests/test_elementwise_functions.py b/array_api_tests/test_elementwise_functions.py index e4e2d5f8..ee714458 100644 --- a/array_api_tests/test_elementwise_functions.py +++ b/array_api_tests/test_elementwise_functions.py @@ -19,7 +19,6 @@ from . import hypothesis_helpers as hh from . import dtype_helpers as dh from . import xps -from .dtype_helpers import dtype_nbits, dtype_signed, promotion_table # We might as well use this implementation rather than requiring # mod.broadcast_shapes(). See test_equal() and others. from .test_broadcasting import broadcast_shapes @@ -46,13 +45,13 @@ def two_array_scalars(draw, dtype1, dtype2): # TODO: refactor this into dtype_helpers.py, see https://github.com/data-apis/array-api-tests/pull/26 def sanity_check(x1, x2): try: - promotion_table[x1.dtype, x2.dtype] + dh.promotion_table[x1.dtype, x2.dtype] except ValueError: raise RuntimeError("Error in test generation (probably a bug in the test suite") @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes)) def test_abs(x): - if ah.is_integer_dtype(x.dtype): + if dh.is_int_dtype(x.dtype): minval = dh.dtype_ranges[x.dtype][0] if minval < 0: # abs of the smallest representable negative integer is not defined @@ -206,7 +205,7 @@ def test_bitwise_and(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) vals_and = val1 & val2 - vals_and = ah.int_to_dtype(vals_and, dtype_nbits[out.dtype], dtype_signed[out.dtype]) + vals_and = ah.int_to_dtype(vals_and, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_and == res @given(*hh.two_mutual_arrays(dh.all_int_dtypes)) @@ -227,8 +226,8 @@ def test_bitwise_left_shift(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) # We avoid shifting very large ints - vals_shift = val1 << val2 if val2 < dtype_nbits[out.dtype] else 0 - vals_shift = ah.int_to_dtype(vals_shift, dtype_nbits[out.dtype], dtype_signed[out.dtype]) + vals_shift = val1 << val2 if val2 < dh.dtype_nbits[out.dtype] else 0 + vals_shift = ah.int_to_dtype(vals_shift, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_shift == res @given(xps.arrays(dtype=hh.integer_or_boolean_dtypes, shape=hh.shapes)) @@ -245,7 +244,7 @@ def test_bitwise_invert(x): val = int(x[idx]) res = int(out[idx]) val_invert = ~val - val_invert = ah.int_to_dtype(val_invert, dtype_nbits[out.dtype], dtype_signed[out.dtype]) + val_invert = ah.int_to_dtype(val_invert, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert val_invert == res @given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes)) @@ -272,7 +271,7 @@ def test_bitwise_or(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) vals_or = val1 | val2 - vals_or = ah.int_to_dtype(vals_or, dtype_nbits[out.dtype], dtype_signed[out.dtype]) + vals_or = ah.int_to_dtype(vals_or, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_or == res @given(*hh.two_mutual_arrays(dh.all_int_dtypes)) @@ -293,7 +292,7 @@ def test_bitwise_right_shift(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) vals_shift = val1 >> val2 - vals_shift = ah.int_to_dtype(vals_shift, dtype_nbits[out.dtype], dtype_signed[out.dtype]) + vals_shift = ah.int_to_dtype(vals_shift, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_shift == res @given(*hh.two_mutual_arrays(dh.bool_and_all_int_dtypes)) @@ -320,7 +319,7 @@ def test_bitwise_xor(x1, x2): val2 = int(_x2[idx]) res = int(out[idx]) vals_xor = val1 ^ val2 - vals_xor = ah.int_to_dtype(vals_xor, dtype_nbits[out.dtype], dtype_signed[out.dtype]) + vals_xor = ah.int_to_dtype(vals_xor, dh.dtype_nbits[out.dtype], dh.dtype_signed[out.dtype]) assert vals_xor == res @given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes)) @@ -397,13 +396,13 @@ def test_equal(x1, x2): # tested in that file, because doing so requires doing the consistency # check we do here rather than just checking the result dtype. - promoted_dtype = promotion_table[x1.dtype, x2.dtype] + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -451,7 +450,7 @@ def test_floor(x): @given(*hh.two_mutual_arrays(dh.numeric_dtypes)) def test_floor_divide(x1, x2): sanity_check(x1, x2) - if ah.is_integer_dtype(x1.dtype): + if dh.is_int_dtype(x1.dtype): # The spec does not specify the behavior for division by 0 for integer # dtypes. A library may choose to raise an exception in this case, so # we avoid passing it in entirely. @@ -483,13 +482,13 @@ def test_greater(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = promotion_table[x1.dtype, x2.dtype] + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -512,13 +511,13 @@ def test_greater_equal(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = promotion_table[x1.dtype, x2.dtype] + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -534,14 +533,14 @@ def test_greater_equal(x1, x2): def test_isfinite(x): a = ah.isfinite(x) TRUE = ah.true(x.shape) - if ah.is_integer_dtype(x.dtype): + if dh.is_int_dtype(x.dtype): ah.assert_exactly_equal(a, TRUE) # Test that isfinite, isinf, and isnan are self-consistent. inf = ah.logical_or(xp.isinf(x), ah.isnan(x)) ah.assert_exactly_equal(a, ah.logical_not(inf)) # Test the exact value by comparing to the math version - if ah.is_float_dtype(x.dtype): + if dh.is_float_dtype(x.dtype): for idx in ah.ndindex(x.shape): s = float(x[idx]) assert bool(a[idx]) == math.isfinite(s) @@ -550,13 +549,13 @@ def test_isfinite(x): def test_isinf(x): a = xp.isinf(x) FALSE = ah.false(x.shape) - if ah.is_integer_dtype(x.dtype): + if dh.is_int_dtype(x.dtype): ah.assert_exactly_equal(a, FALSE) finite_or_nan = ah.logical_or(ah.isfinite(x), ah.isnan(x)) ah.assert_exactly_equal(a, ah.logical_not(finite_or_nan)) # Test the exact value by comparing to the math version - if ah.is_float_dtype(x.dtype): + if dh.is_float_dtype(x.dtype): for idx in ah.ndindex(x.shape): s = float(x[idx]) assert bool(a[idx]) == math.isinf(s) @@ -565,13 +564,13 @@ def test_isinf(x): def test_isnan(x): a = ah.isnan(x) FALSE = ah.false(x.shape) - if ah.is_integer_dtype(x.dtype): + if dh.is_int_dtype(x.dtype): ah.assert_exactly_equal(a, FALSE) finite_or_inf = ah.logical_or(ah.isfinite(x), xp.isinf(x)) ah.assert_exactly_equal(a, ah.logical_not(finite_or_inf)) # Test the exact value by comparing to the math version - if ah.is_float_dtype(x.dtype): + if dh.is_float_dtype(x.dtype): for idx in ah.ndindex(x.shape): s = float(x[idx]) assert bool(a[idx]) == math.isnan(s) @@ -587,13 +586,13 @@ def test_less(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = promotion_table[x1.dtype, x2.dtype] + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -616,13 +615,13 @@ def test_less_equal(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = promotion_table[x1.dtype, x2.dtype] + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool @@ -749,7 +748,7 @@ def test_negative(x): ah.assert_exactly_equal(x, ah.negative(out)) mask = ah.isfinite(x) - if ah.is_integer_dtype(x.dtype): + if dh.is_int_dtype(x.dtype): minval = dh.dtype_ranges[x.dtype][0] if minval < 0: # negative of the smallest representable negative integer is not defined @@ -772,13 +771,13 @@ def test_not_equal(x1, x2): _x1 = xp.broadcast_to(x1, shape) _x2 = xp.broadcast_to(x2, shape) - promoted_dtype = promotion_table[x1.dtype, x2.dtype] + promoted_dtype = dh.promotion_table[x1.dtype, x2.dtype] _x1 = ah.asarray(_x1, dtype=promoted_dtype) _x2 = ah.asarray(_x2, dtype=promoted_dtype) - if ah.is_integer_dtype(promoted_dtype): + if dh.is_int_dtype(promoted_dtype): scalar_func = int - elif ah.is_float_dtype(promoted_dtype): + elif dh.is_float_dtype(promoted_dtype): scalar_func = float else: scalar_func = bool From 6976d1940a640086a87f355beec5ab8877e467b5 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Oct 2021 17:01:37 +0100 Subject: [PATCH 24/41] Rudimentary clean-up of scalar promotion test --- array_api_tests/test_type_promotion.py | 92 +++++++++++++++++++++----- 1 file changed, 77 insertions(+), 15 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index e23fdc7d..1222a98a 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -2,7 +2,7 @@ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ from collections import defaultdict -from typing import Iterator, TypeVar, Tuple, Callable +from typing import Iterator, TypeVar, Tuple, Callable, Type import pytest from hypothesis import assume, given @@ -79,8 +79,7 @@ def test_func_returns_array_with_correct_dtype( ): if len(in_dtypes) == 1: x = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), - label='x', + xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' ) out = func(x) else: @@ -90,8 +89,7 @@ def test_func_returns_array_with_correct_dtype( ) for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): x = data.draw( - xps.arrays(dtype=dtype, shape=shape).filter(x_filter), - label=f'x{i}', + xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}' ) arrays.append(x) out = func(*arrays) @@ -148,8 +146,7 @@ def test_operator_returns_array_with_correct_dtype( ): if len(in_dtypes) == 1: x = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), - label='x', + xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' ) out = eval(expr, {'x': x}) else: @@ -159,8 +156,7 @@ def test_operator_returns_array_with_correct_dtype( ) for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): locals_[f'x{i}'] = data.draw( - xps.arrays(dtype=dtype, shape=shape).filter(x_filter), - label=f'x{i}', + xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}' ) out = eval(expr, locals_) assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' @@ -172,19 +168,19 @@ def gen_inplace_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: continue in_category = dh.op_in_categories[op] valid_in_dtypes = dh.category_to_dtypes[in_category] + iop = f'__i{op[2:]}' for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): if ( in_dtype1 == promoted_dtype and in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes ): - iop = f'__i{op[2:]}' yield pytest.param( f'x1 {symbol}= x2', (in_dtype1, in_dtype2), promoted_dtype, filters[iop], - id=f'{iop}({in_dtype1}, {in_dtype2}) -> {in_dtype1}', + id=f'{iop}({in_dtype1}, {in_dtype2}) -> {promoted_dtype}', ) @@ -195,12 +191,10 @@ def test_inplace_operator_returns_array_with_correct_dtype( ): assume(len(shapes[0]) >= len(shapes[1])) x1 = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=shapes[0]).filter(x_filter), - label='x1', + xps.arrays(dtype=in_dtypes[0], shape=shapes[0]).filter(x_filter), label='x1' ) x2 = data.draw( - xps.arrays(dtype=in_dtypes[1], shape=shapes[1]).filter(x_filter), - label='x2', + xps.arrays(dtype=in_dtypes[1], shape=shapes[1]).filter(x_filter), label='x2' ) locals_ = {'x1': x1, 'x2': x2} exec(expr, locals_) @@ -208,6 +202,74 @@ def test_inplace_operator_returns_array_with_correct_dtype( assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' +def gen_op_scalar_params() -> Iterator[Tuple[str, DT, Type[float], DT, Callable]]: + for op, symbol in dh.binary_op_to_symbol.items(): + if op == '__matmul__': + continue + in_category = dh.op_in_categories[op] + out_category = dh.op_out_categories[op] + for in_dtype in dh.category_to_dtypes[in_category]: + out_dtype = in_dtype if out_category == 'promoted' else xp.bool + for in_stype in dh.dtypes_to_scalars[in_dtype]: + yield pytest.param( + f'x {symbol} s', + in_dtype, + in_stype, + out_dtype, + filters[op], + id=f'{op}({in_dtype}, {in_stype.__name__}) -> {out_dtype}', + ) + + +@pytest.mark.parametrize( + 'expr, in_dtype, in_stype, out_dtype, x_filter', gen_op_scalar_params() +) +@given(data=st.data()) +def test_binary_operator_promotes_python_scalars( + expr, in_dtype, in_stype, out_dtype, x_filter, data +): + # TODO: do not trigger undefined behaviours (overflows, infs, nans) + kw = {} if in_stype is float else {'allow_nan': False, 'allow_infinity': False} + s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label=f'scalar') + x = data.draw( + xps.arrays(dtype=in_dtype, shape=hh.shapes).filter(x_filter), label='x' + ) + out = eval(expr, {'x': x, 's': s}) + assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' + + +def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, Type[float], Callable]]: + for op, symbol in dh.binary_op_to_symbol.items(): + if op == '__matmul__' or dh.op_out_categories[op] == 'bool': + continue + in_category = dh.op_in_categories[op] + iop = f'__i{op[2:]}' + for dtype in dh.category_to_dtypes[in_category]: + for in_stype in dh.dtypes_to_scalars[dtype]: + yield pytest.param( + f'x {symbol}= s', + dtype, + in_stype, + filters[iop], + id=f'{iop}({dtype}, {in_stype.__name__}) -> {dtype}', + ) + + +@pytest.mark.parametrize('expr, dtype, in_stype, x_filter', gen_inplace_scalar_params()) +@given(data=st.data()) +def test_inplace_operator_promotes_python_scalars( + expr, dtype, in_stype, x_filter, data +): + # TODO: do not trigger undefined behaviours (overflows, infs, nans) + kw = {} if in_stype is float else {'allow_nan': False, 'allow_infinity': False} + s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label=f'scalar') + x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes).filter(x_filter), label='x') + locals_ = {'x': x, 's': s} + exec(expr, locals_) + x = locals_['x'] + assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}' + + scalar_promotion_parametrize_inputs = [ pytest.param(func, dtype, scalar_type, id=f'{func}-{dtype}-{scalar_type.__name__}') for func in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'}) From b98997c537b17dcb55123fa0253faa0eb039de07 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Oct 2021 17:30:25 +0100 Subject: [PATCH 25/41] Ignore testing NaNs and infs in `test_type_promotion.py` --- array_api_tests/test_type_promotion.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 1222a98a..cfe21d0b 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -221,6 +221,10 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, DT, Type[float], DT, Callable] ) +# We ignore generating NaNs and infs - they are tested in the special_cases directory +dtype_kw = {'allow_nan': False, 'allow_infinity': False} + + @pytest.mark.parametrize( 'expr, in_dtype, in_stype, out_dtype, x_filter', gen_op_scalar_params() ) @@ -228,9 +232,8 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, DT, Type[float], DT, Callable] def test_binary_operator_promotes_python_scalars( expr, in_dtype, in_stype, out_dtype, x_filter, data ): - # TODO: do not trigger undefined behaviours (overflows, infs, nans) - kw = {} if in_stype is float else {'allow_nan': False, 'allow_infinity': False} - s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label=f'scalar') + # TODO: do not trigger overflows + s = data.draw(xps.from_dtype(in_dtype, **dtype_kw).map(in_stype), label=f'scalar') x = data.draw( xps.arrays(dtype=in_dtype, shape=hh.shapes).filter(x_filter), label='x' ) @@ -260,9 +263,8 @@ def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, Type[float], Callable def test_inplace_operator_promotes_python_scalars( expr, dtype, in_stype, x_filter, data ): - # TODO: do not trigger undefined behaviours (overflows, infs, nans) - kw = {} if in_stype is float else {'allow_nan': False, 'allow_infinity': False} - s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label=f'scalar') + # TODO: do not trigger overflows + s = data.draw(xps.from_dtype(dtype, **dtype_kw).map(in_stype), label=f'scalar') x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes).filter(x_filter), label='x') locals_ = {'x': x, 's': s} exec(expr, locals_) From 3fe4b0758dd48ffe1fc225e8236449b359caaa81 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Wed, 13 Oct 2021 18:40:19 +0100 Subject: [PATCH 26/41] Rudimentary array element factories for scalar promotion tests --- array_api_tests/dtype_helpers.py | 23 +++++--- array_api_tests/test_type_promotion.py | 73 ++++++++++++++++++++------ 2 files changed, 71 insertions(+), 25 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 05c40a7b..77d3574f 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,3 +1,5 @@ +from typing import NamedTuple + from . import _array_module as xp @@ -58,15 +60,20 @@ def is_float_dtype(dtype): return dtype in float_dtypes +class MinMax(NamedTuple): + min: int + max: int + + dtype_ranges = { - xp.int8: [-128, +127], - xp.int16: [-32_768, +32_767], - xp.int32: [-2_147_483_648, +2_147_483_647], - xp.int64: [-9_223_372_036_854_775_808, +9_223_372_036_854_775_807], - xp.uint8: [0, +255], - xp.uint16: [0, +65_535], - xp.uint32: [0, +4_294_967_295], - xp.uint64: [0, +18_446_744_073_709_551_615], + xp.int8: MinMax(-128, +127), + xp.int16: MinMax(-32_768, +32_767), + xp.int32: MinMax(-2_147_483_648, +2_147_483_647), + xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), + xp.uint8: MinMax(0, +255), + xp.uint16: MinMax(0, +65_535), + xp.uint32: MinMax(0, +4_294_967_295), + xp.uint64: MinMax(0, +18_446_744_073_709_551_615), } diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index cfe21d0b..f7dae573 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -2,7 +2,7 @@ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ from collections import defaultdict -from typing import Iterator, TypeVar, Tuple, Callable, Type +from typing import Iterator, TypeVar, Tuple, Callable, Type, Union import pytest from hypothesis import assume, given @@ -202,7 +202,39 @@ def test_inplace_operator_returns_array_with_correct_dtype( assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' -def gen_op_scalar_params() -> Iterator[Tuple[str, DT, Type[float], DT, Callable]]: +finite_kw = {'allow_nan': False, 'allow_infinity': False} + + +int_kw_factories = defaultdict( + lambda: lambda m, M: lambda s: {}, + { + '__add__': lambda m, M: lambda s: { + 'min_value': max(m - s, m), + 'max_value': min(M - s, M), + }, + '__sub__': lambda m, M: lambda s: { + 'min_value': max(m + s, m), + 'max_value': min(M + s, M), + }, + # TODO: cover all the ops which require element factories + }, +) + + +def make_elements_factory(op, dtype): + if dh.is_int_dtype(dtype): + m, M = dh.dtype_ranges[dtype] + return int_kw_factories[op](m, M) + else: + return lambda _: finite_kw + + +ScalarType = Union[Type[bool], Type[int], Type[float]] + + +def gen_op_scalar_params() -> Iterator[ + Tuple[str, DT, ScalarType, DT, Callable, Callable] +]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__': continue @@ -217,31 +249,32 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, DT, Type[float], DT, Callable] in_stype, out_dtype, filters[op], + make_elements_factory(op, in_dtype), id=f'{op}({in_dtype}, {in_stype.__name__}) -> {out_dtype}', ) -# We ignore generating NaNs and infs - they are tested in the special_cases directory -dtype_kw = {'allow_nan': False, 'allow_infinity': False} - - @pytest.mark.parametrize( - 'expr, in_dtype, in_stype, out_dtype, x_filter', gen_op_scalar_params() + 'expr, in_dtype, in_stype, out_dtype, x_filter, elements_factory', + gen_op_scalar_params(), ) @given(data=st.data()) def test_binary_operator_promotes_python_scalars( - expr, in_dtype, in_stype, out_dtype, x_filter, data + expr, in_dtype, in_stype, out_dtype, x_filter, elements_factory, data ): - # TODO: do not trigger overflows - s = data.draw(xps.from_dtype(in_dtype, **dtype_kw).map(in_stype), label=f'scalar') + s = data.draw(xps.from_dtype(in_dtype, **finite_kw).map(in_stype), label='scalar') + elements = elements_factory(s) x = data.draw( - xps.arrays(dtype=in_dtype, shape=hh.shapes).filter(x_filter), label='x' + xps.arrays(dtype=in_dtype, shape=hh.shapes, elements=elements).filter(x_filter), + label='x', ) out = eval(expr, {'x': x, 's': s}) assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, Type[float], Callable]]: +def gen_inplace_scalar_params() -> Iterator[ + Tuple[str, DT, ScalarType, Callable, Callable] +]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__' or dh.op_out_categories[op] == 'bool': continue @@ -254,18 +287,24 @@ def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, Type[float], Callable dtype, in_stype, filters[iop], + make_elements_factory(op, dtype), id=f'{iop}({dtype}, {in_stype.__name__}) -> {dtype}', ) -@pytest.mark.parametrize('expr, dtype, in_stype, x_filter', gen_inplace_scalar_params()) +@pytest.mark.parametrize( + 'expr, dtype, in_stype, x_filter, elements_factory', gen_inplace_scalar_params() +) @given(data=st.data()) def test_inplace_operator_promotes_python_scalars( - expr, dtype, in_stype, x_filter, data + expr, dtype, in_stype, x_filter, elements_factory, data ): - # TODO: do not trigger overflows - s = data.draw(xps.from_dtype(dtype, **dtype_kw).map(in_stype), label=f'scalar') - x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes).filter(x_filter), label='x') + s = data.draw(xps.from_dtype(dtype, **finite_kw).map(in_stype), label='scalar') + elements = elements_factory(s) + x = data.draw( + xps.arrays(dtype=dtype, shape=hh.shapes, elements=elements).filter(x_filter), + label='x', + ) locals_ = {'x': x, 's': s} exec(expr, locals_) x = locals_['x'] From 198df7ecce5700e13166f2ec42f4a8c2d5b53d94 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Oct 2021 09:47:32 +0100 Subject: [PATCH 27/41] Remove elements factory --- array_api_tests/test_type_promotion.py | 67 ++++++-------------------- 1 file changed, 16 insertions(+), 51 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index f7dae573..f6139085 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -203,38 +203,10 @@ def test_inplace_operator_returns_array_with_correct_dtype( finite_kw = {'allow_nan': False, 'allow_infinity': False} - - -int_kw_factories = defaultdict( - lambda: lambda m, M: lambda s: {}, - { - '__add__': lambda m, M: lambda s: { - 'min_value': max(m - s, m), - 'max_value': min(M - s, M), - }, - '__sub__': lambda m, M: lambda s: { - 'min_value': max(m + s, m), - 'max_value': min(M + s, M), - }, - # TODO: cover all the ops which require element factories - }, -) - - -def make_elements_factory(op, dtype): - if dh.is_int_dtype(dtype): - m, M = dh.dtype_ranges[dtype] - return int_kw_factories[op](m, M) - else: - return lambda _: finite_kw - - ScalarType = Union[Type[bool], Type[int], Type[float]] -def gen_op_scalar_params() -> Iterator[ - Tuple[str, DT, ScalarType, DT, Callable, Callable] -]: +def gen_op_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, DT, Callable]]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__': continue @@ -249,32 +221,29 @@ def gen_op_scalar_params() -> Iterator[ in_stype, out_dtype, filters[op], - make_elements_factory(op, in_dtype), id=f'{op}({in_dtype}, {in_stype.__name__}) -> {out_dtype}', ) @pytest.mark.parametrize( - 'expr, in_dtype, in_stype, out_dtype, x_filter, elements_factory', - gen_op_scalar_params(), + 'expr, in_dtype, in_stype, out_dtype, x_filter', gen_op_scalar_params() ) @given(data=st.data()) def test_binary_operator_promotes_python_scalars( - expr, in_dtype, in_stype, out_dtype, x_filter, elements_factory, data + expr, in_dtype, in_stype, out_dtype, x_filter, data ): s = data.draw(xps.from_dtype(in_dtype, **finite_kw).map(in_stype), label='scalar') - elements = elements_factory(s) x = data.draw( - xps.arrays(dtype=in_dtype, shape=hh.shapes, elements=elements).filter(x_filter), - label='x', + xps.arrays(dtype=in_dtype, shape=hh.shapes).filter(x_filter), label='x' ) - out = eval(expr, {'x': x, 's': s}) + try: + out = eval(expr, {'x': x, 's': s}) + except OverflowError: + assume(False) assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def gen_inplace_scalar_params() -> Iterator[ - Tuple[str, DT, ScalarType, Callable, Callable] -]: +def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, Callable]]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__' or dh.op_out_categories[op] == 'bool': continue @@ -287,26 +256,22 @@ def gen_inplace_scalar_params() -> Iterator[ dtype, in_stype, filters[iop], - make_elements_factory(op, dtype), id=f'{iop}({dtype}, {in_stype.__name__}) -> {dtype}', ) -@pytest.mark.parametrize( - 'expr, dtype, in_stype, x_filter, elements_factory', gen_inplace_scalar_params() -) +@pytest.mark.parametrize('expr, dtype, in_stype, x_filter', gen_inplace_scalar_params()) @given(data=st.data()) def test_inplace_operator_promotes_python_scalars( - expr, dtype, in_stype, x_filter, elements_factory, data + expr, dtype, in_stype, x_filter, data ): s = data.draw(xps.from_dtype(dtype, **finite_kw).map(in_stype), label='scalar') - elements = elements_factory(s) - x = data.draw( - xps.arrays(dtype=dtype, shape=hh.shapes, elements=elements).filter(x_filter), - label='x', - ) + x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes).filter(x_filter), label='x') locals_ = {'x': x, 's': s} - exec(expr, locals_) + try: + exec(expr, locals_) + except OverflowError: + assume(False) x = locals_['x'] assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}' From e40cac8f80feb918a96e3acc87d53d5d41e5f86c Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Oct 2021 10:05:10 +0100 Subject: [PATCH 28/41] Reject on OverflowError --- array_api_tests/test_type_promotion.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index f6139085..324755b7 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -5,7 +5,7 @@ from typing import Iterator, TypeVar, Tuple, Callable, Type, Union import pytest -from hypothesis import assume, given +from hypothesis import assume, given, reject from hypothesis import strategies as st from . import _array_module as xp @@ -81,7 +81,7 @@ def test_func_returns_array_with_correct_dtype( x = data.draw( xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' ) - out = func(x) + arrays = [x] else: arrays = [] shapes = data.draw( @@ -92,7 +92,10 @@ def test_func_returns_array_with_correct_dtype( xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}' ) arrays.append(x) + try: out = func(*arrays) + except OverflowError: + reject() assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' @@ -144,13 +147,12 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: def test_operator_returns_array_with_correct_dtype( expr, in_dtypes, out_dtype, x_filter, data ): + locals_ = {} if len(in_dtypes) == 1: - x = data.draw( + locals_['x'] = data.draw( xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' ) - out = eval(expr, {'x': x}) else: - locals_ = {} shapes = data.draw( hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes' ) @@ -158,7 +160,10 @@ def test_operator_returns_array_with_correct_dtype( locals_[f'x{i}'] = data.draw( xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}' ) + try: out = eval(expr, locals_) + except OverflowError: + reject() assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' @@ -197,7 +202,10 @@ def test_inplace_operator_returns_array_with_correct_dtype( xps.arrays(dtype=in_dtypes[1], shape=shapes[1]).filter(x_filter), label='x2' ) locals_ = {'x1': x1, 'x2': x2} - exec(expr, locals_) + try: + exec(expr, locals_) + except OverflowError: + reject() x1 = locals_['x1'] assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' @@ -239,7 +247,7 @@ def test_binary_operator_promotes_python_scalars( try: out = eval(expr, {'x': x, 's': s}) except OverflowError: - assume(False) + reject() assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' @@ -271,7 +279,7 @@ def test_inplace_operator_promotes_python_scalars( try: exec(expr, locals_) except OverflowError: - assume(False) + reject() x = locals_['x'] assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}' From b37ecc970194f03c3cb28023158aa0e343af58f7 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Oct 2021 10:06:12 +0100 Subject: [PATCH 29/41] Remove old scalar promotion tests --- array_api_tests/test_type_promotion.py | 68 -------------------------- 1 file changed, 68 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 324755b7..630736ae 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -284,74 +284,6 @@ def test_inplace_operator_promotes_python_scalars( assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}' -scalar_promotion_parametrize_inputs = [ - pytest.param(func, dtype, scalar_type, id=f'{func}-{dtype}-{scalar_type.__name__}') - for func in sorted(set(dh.binary_op_to_symbol) - {'__matmul__'}) - for dtype in dh.category_to_dtypes[dh.op_in_categories[func]] - for scalar_type in dh.dtypes_to_scalars[dtype] -] - - -@pytest.mark.parametrize('func,dtype,scalar_type', scalar_promotion_parametrize_inputs) -@given(shape=hh.shapes, python_scalars=st.data(), data=st.data()) -def test_operator_scalar_arg_return_promoted( - func, dtype, scalar_type, shape, python_scalars, data -): - """ - See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars - """ - op = dh.binary_op_to_symbol[func] - if op == '@': - pytest.skip('matmul (@) is not supported for hh.scalars') - - if dtype in dh.category_to_dtypes['integer']: - s = python_scalars.draw(st.integers(*dh.dtype_ranges[dtype])) - else: - s = python_scalars.draw(st.from_type(scalar_type)) - scalar_as_array = ah.asarray(s, dtype=dtype) - get_locals = lambda: dict(a=a, s=s, scalar_as_array=scalar_as_array) - - fillvalue = data.draw(hh.scalars(st.just(dtype))) - a = ah.full(shape, fillvalue, dtype=dtype) - - # As per the spec: - - # The expected behavior is then equivalent to: - # - # 1. Convert the scalar to a 0-D array with the same dtype as that of the - # array used in the expression. - # - # 2. Execute the operation for `array 0-D array` (or `0-D array - # array` if `scalar` was the left-hand argument). - - array_scalar = f'a {op} s' - array_scalar_expected = f'a {op} scalar_as_array' - res = eval(array_scalar, get_locals()) - expected = eval(array_scalar_expected, get_locals()) - ah.assert_exactly_equal(res, expected) - - scalar_array = f's {op} a' - scalar_array_expected = f'scalar_as_array {op} a' - res = eval(scalar_array, get_locals()) - expected = eval(scalar_array_expected, get_locals()) - ah.assert_exactly_equal(res, expected) - - # Test in-place operators - if op in ['==', '!=', '<', '>', '<=', '>=']: - return - array_scalar = f'a {op}= s' - array_scalar_expected = f'a {op}= scalar_as_array' - a = ah.full(shape, fillvalue, dtype=dtype) - res_locals = get_locals() - exec(array_scalar, get_locals()) - res = res_locals['a'] - a = ah.full(shape, fillvalue, dtype=dtype) - expected_locals = get_locals() - exec(array_scalar_expected, get_locals()) - expected = expected_locals['a'] - ah.assert_exactly_equal(res, expected) - - if __name__ == '__main__': for (i, j), p in dh.promotion_table.items(): print(f'({i}, {j}) -> {p}') From 06fc784b80357110d2b398794bf461f4f5382615 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Oct 2021 10:26:45 +0100 Subject: [PATCH 30/41] Construct `inplace_op_to_symbol` in `dtype_helpers` --- array_api_tests/dtype_helpers.py | 19 +++++++++++++++---- array_api_tests/test_type_promotion.py | 22 ++++++++++------------ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 77d3574f..1422111f 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -23,9 +23,10 @@ 'func_out_categories', 'op_in_categories', 'op_out_categories', + 'op_to_func', 'binary_op_to_symbol', 'unary_op_to_symbol', - 'op_to_func', + 'inplace_op_to_symbol', ] @@ -328,6 +329,16 @@ class MinMax(NamedTuple): op_in_categories = {} op_out_categories = {} -for op_func, elwise_func in op_to_func.items(): - op_in_categories[op_func] = func_in_categories[elwise_func] - op_out_categories[op_func] = func_out_categories[elwise_func] +for op, elwise_func in op_to_func.items(): + op_in_categories[op] = func_in_categories[elwise_func] + op_out_categories[op] = func_out_categories[elwise_func] + + +inplace_op_to_symbol = {} +for op, symbol in binary_op_to_symbol.items(): + if op == '__matmul__' or op_out_categories[op] == 'bool': + continue + iop = f'__i{op[2:]}' + inplace_op_to_symbol[iop] = f'{symbol}=' + op_in_categories[iop] = op_in_categories[op] + op_out_categories[iop] = op_out_categories[op] diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 630736ae..42292a1d 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -168,12 +168,11 @@ def test_operator_returns_array_with_correct_dtype( def gen_inplace_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: - for op, symbol in dh.binary_op_to_symbol.items(): - if op == '__matmul__' or dh.op_out_categories[op] == 'bool': + for op, symbol in dh.inplace_op_to_symbol.items(): + if op == '__imatmul__': continue in_category = dh.op_in_categories[op] valid_in_dtypes = dh.category_to_dtypes[in_category] - iop = f'__i{op[2:]}' for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): if ( in_dtype1 == promoted_dtype @@ -181,11 +180,11 @@ def gen_inplace_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: and in_dtype2 in valid_in_dtypes ): yield pytest.param( - f'x1 {symbol}= x2', + f'x1 {symbol} x2', (in_dtype1, in_dtype2), promoted_dtype, - filters[iop], - id=f'{iop}({in_dtype1}, {in_dtype2}) -> {promoted_dtype}', + filters[op], + id=f'{op}({in_dtype1}, {in_dtype2}) -> {promoted_dtype}', ) @@ -252,19 +251,18 @@ def test_binary_operator_promotes_python_scalars( def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, Callable]]: - for op, symbol in dh.binary_op_to_symbol.items(): - if op == '__matmul__' or dh.op_out_categories[op] == 'bool': + for op, symbol in dh.inplace_op_to_symbol.items(): + if op == '__imatmul__': continue in_category = dh.op_in_categories[op] - iop = f'__i{op[2:]}' for dtype in dh.category_to_dtypes[in_category]: for in_stype in dh.dtypes_to_scalars[dtype]: yield pytest.param( - f'x {symbol}= s', + f'x {symbol} s', dtype, in_stype, - filters[iop], - id=f'{iop}({dtype}, {in_stype.__name__}) -> {dtype}', + filters[op], + id=f'{op}({dtype}, {in_stype.__name__}) -> {dtype}', ) From eab08d984427854fce4065ffaedb5914e5757c93 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Oct 2021 10:38:30 +0100 Subject: [PATCH 31/41] Generate NaNs and infs for scalar promotion tests --- array_api_tests/test_type_promotion.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 42292a1d..ef201e2c 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -209,7 +209,6 @@ def test_inplace_operator_returns_array_with_correct_dtype( assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' -finite_kw = {'allow_nan': False, 'allow_infinity': False} ScalarType = Union[Type[bool], Type[int], Type[float]] @@ -239,7 +238,8 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, DT, Callable]] def test_binary_operator_promotes_python_scalars( expr, in_dtype, in_stype, out_dtype, x_filter, data ): - s = data.draw(xps.from_dtype(in_dtype, **finite_kw).map(in_stype), label='scalar') + kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} + s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar') x = data.draw( xps.arrays(dtype=in_dtype, shape=hh.shapes).filter(x_filter), label='x' ) @@ -271,7 +271,8 @@ def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, Callable] def test_inplace_operator_promotes_python_scalars( expr, dtype, in_stype, x_filter, data ): - s = data.draw(xps.from_dtype(dtype, **finite_kw).map(in_stype), label='scalar') + kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} + s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar') x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes).filter(x_filter), label='x') locals_ = {'x': x, 's': s} try: From 4907178c9ae39c9d9c46cc5022ea87d52fbe62c5 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Oct 2021 10:54:30 +0100 Subject: [PATCH 32/41] Dont catch OverflowError for single array funcs --- array_api_tests/test_type_promotion.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index ef201e2c..aad348eb 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -81,7 +81,7 @@ def test_func_returns_array_with_correct_dtype( x = data.draw( xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' ) - arrays = [x] + out = func(x) else: arrays = [] shapes = data.draw( @@ -92,10 +92,10 @@ def test_func_returns_array_with_correct_dtype( xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}' ) arrays.append(x) - try: - out = func(*arrays) - except OverflowError: - reject() + try: + out = func(*arrays) + except OverflowError: + reject() assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' @@ -147,12 +147,13 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: def test_operator_returns_array_with_correct_dtype( expr, in_dtypes, out_dtype, x_filter, data ): - locals_ = {} if len(in_dtypes) == 1: - locals_['x'] = data.draw( + x = data.draw( xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' ) + out = eval(expr, {'x': x}) else: + locals_ = {} shapes = data.draw( hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes' ) @@ -160,10 +161,10 @@ def test_operator_returns_array_with_correct_dtype( locals_[f'x{i}'] = data.draw( xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}' ) - try: - out = eval(expr, locals_) - except OverflowError: - reject() + try: + out = eval(expr, locals_) + except OverflowError: + reject() assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' From 75f44d7ab61ad441ec270cb49aa1c4ac78d51589 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Oct 2021 10:56:52 +0100 Subject: [PATCH 33/41] Fix regression of `dtype_objects` instead of updated `dtype_objs` --- array_api_tests/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index fe9a4fdf..47052105 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -339,7 +339,7 @@ def test_matrix_transpose(x): _test_stacks(linalg.matrix_transpose, x, res=res, true_val=true_val) @given( - *two_mutual_arrays(dtype_objects=dh.numeric_dtypes, + *two_mutual_arrays(dtype_objs=dh.numeric_dtypes, two_shapes=tuples(one_d_shapes, one_d_shapes)) ) def test_outer(x1, x2): From d4467f1e970d6243fb7c1cd95c7b9da8f4ccb1fe Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Thu, 14 Oct 2021 12:24:47 +0100 Subject: [PATCH 34/41] Sort `promotable_dtypes` without relying on dtype comparisons --- array_api_tests/hypothesis_helpers.py | 42 +++++++++++++++++---------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 3d2419d2..8122efff 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -47,27 +47,39 @@ shared_dtypes = shared(dtypes, key="dtype") shared_floating_dtypes = shared(floating_dtypes, key="dtype") +_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes] +_sorted_dtypes = [d for category in _dtype_categories for d in category] + +def _dtypes_sorter(dtype_pair): + dtype1, dtype2 = dtype_pair + if dtype1 == dtype2: + return _sorted_dtypes.index(dtype1) + key = len(_sorted_dtypes) + rank1 = _sorted_dtypes.index(dtype1) + rank2 = _sorted_dtypes.index(dtype2) + for category in _dtype_categories: + if dtype1 in category and dtype2 in category: + break + else: + key += len(_sorted_dtypes) ** 2 + key += 2 * (rank1 + rank2) + if rank1 > rank2: + key += 1 + return key + +promotable_dtypes = sorted(dh.promotion_table.keys(), key=_dtypes_sorter) -sorted_table = sorted(dh.promotion_table) -sorted_table = sorted( - sorted_table, key=lambda ij: -1 if ij[0] == ij[1] else sorted_table.index(ij) -) if FILTER_UNDEFINED_DTYPES: - sorted_table = [(i, j) for i, j in sorted_table - if not isinstance(i, _UndefinedStub) - and not isinstance(j, _UndefinedStub)] + promotable_dtypes = [ + (i, j) for i, j in promotable_dtypes + if not isinstance(i, _UndefinedStub) + and not isinstance(j, _UndefinedStub) + ] def mutually_promotable_dtypes(dtype_objs=dh.all_dtypes): - # sort for shrinking (sampled_from shrinks to the earlier elements in the - # list). Give pairs of the same dtypes first, then smaller dtypes, - # preferring float, then int, then unsigned int. Note, this might not - # always result in examples shrinking to these pairs because strategies - # that draw from dtypes might not draw the same example from different - # pairs (XXX: Can we redesign the strategies so that they can prefer - # shrinking dtypes over values?) return sampled_from( - [(i, j) for i, j in sorted_table if i in dtype_objs and j in dtype_objs] + [(i, j) for i, j in promotable_dtypes if i in dtype_objs and j in dtype_objs] ) # shared() allows us to draw either the function or the function name and they From 64ef13c7cf559c26bce4013b46fe3510f1614c8e Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 18 Oct 2021 10:39:44 +0100 Subject: [PATCH 35/41] Don't rely on str(dtype) for parameter ids --- array_api_tests/dtype_helpers.py | 20 +++++++++--- array_api_tests/test_type_promotion.py | 43 +++++++++++++++++--------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 1422111f..614bb403 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -10,8 +10,9 @@ 'float_dtypes', 'numeric_dtypes', 'all_dtypes', + 'dtype_to_name', 'bool_and_all_int_dtypes', - 'dtypes_to_scalars', + 'dtype_to_scalars', 'is_int_dtype', 'is_float_dtype', 'dtype_ranges', @@ -30,16 +31,25 @@ ] -int_dtypes = (xp.int8, xp.int16, xp.int32, xp.int64) -uint_dtypes = (xp.uint8, xp.uint16, xp.uint32, xp.uint64) +_int_names = ('int8', 'int16', 'int32', 'int64') +_uint_names = ('uint8', 'uint16', 'uint32', 'uint64') +_float_names = ('float32', 'float64') +_dtype_names = ('bool',) + _int_names + _uint_names + _float_names + + +int_dtypes = tuple(getattr(xp, name) for name in _int_names) +uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) +float_dtypes = tuple(getattr(xp, name) for name in _float_names) all_int_dtypes = int_dtypes + uint_dtypes -float_dtypes = (xp.float32, xp.float64) numeric_dtypes = all_int_dtypes + float_dtypes all_dtypes = (xp.bool,) + numeric_dtypes bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes -dtypes_to_scalars = { +dtype_to_name = {getattr(xp, name): name for name in _dtype_names} + + +dtype_to_scalars = { xp.bool: [bool], **{d: [int] for d in all_int_dtypes}, **{d: [int, float] for d in float_dtypes}, diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index aad348eb..2b7d5e0b 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -2,7 +2,7 @@ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ from collections import defaultdict -from typing import Iterator, TypeVar, Tuple, Callable, Type, Union +from typing import Iterator, Tuple, Callable, Type, Union import pytest from hypothesis import assume, given, reject @@ -27,7 +27,8 @@ ] -DT = TypeVar('DT') +DT = Type +ScalarType = Union[Type[bool], Type[int], Type[float]] # We apply filters to xps.arrays() so we don't generate array elements that @@ -38,6 +39,21 @@ ) +def make_id( + func_name: str, in_dtypes: Tuple[Union[DT, ScalarType], ...], out_dtype: DT +) -> str: + f_in_dtypes = [] + for dtype in in_dtypes: + try: + f_in_dtypes.append(dh.dtype_to_name[dtype]) + except KeyError: + # i.e. dtype is bool, int, or float + f_in_dtypes.append(dtype.__name__) + f_args = ', '.join(f_in_dtypes) + f_out_dtype = dh.dtype_to_name[out_dtype] + return f'{func_name}({f_args}) -> {f_out_dtype}' + + def gen_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]]: for func_name in elementwise_functions.__all__: func = getattr(xp, func_name) @@ -53,7 +69,7 @@ def gen_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]] (in_dtype,), out_dtype, filters[func_name], - id=f'{func_name}({in_dtype}) -> {out_dtype}', + id=make_id(func_name, (in_dtype,), out_dtype), ) elif ndtypes == 2: for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): @@ -66,7 +82,7 @@ def gen_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]] (in_dtype1, in_dtype2), out_dtype, filters[func_name], - id=f'{func_name}({in_dtype1}, {in_dtype2}) -> {out_dtype}', + id=make_id(func_name, (in_dtype1, in_dtype2), out_dtype), ) else: raise NotImplementedError() @@ -116,7 +132,7 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: (in_dtype,), out_dtype, filters[op], - id=f'{op}({in_dtype}) -> {out_dtype}', + id=make_id(op, (in_dtype,), out_dtype), ) else: for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): @@ -129,7 +145,7 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: (in_dtype1, in_dtype2), out_dtype, filters[op], - id=f'{op}({in_dtype1}, {in_dtype2}) -> {out_dtype}', + id=make_id(op, (in_dtype1, in_dtype2), out_dtype), ) # We generate params for abs seperately as it does not have an associated symbol for in_dtype in dh.category_to_dtypes[dh.op_in_categories['__abs__']]: @@ -138,7 +154,7 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: (in_dtype,), in_dtype, filters['__abs__'], - id=f'__abs__({in_dtype}) -> {in_dtype}', + id=make_id('__abs__', (in_dtype,), in_dtype), ) @@ -185,7 +201,7 @@ def gen_inplace_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: (in_dtype1, in_dtype2), promoted_dtype, filters[op], - id=f'{op}({in_dtype1}, {in_dtype2}) -> {promoted_dtype}', + id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype), ) @@ -210,9 +226,6 @@ def test_inplace_operator_returns_array_with_correct_dtype( assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' -ScalarType = Union[Type[bool], Type[int], Type[float]] - - def gen_op_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, DT, Callable]]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__': @@ -221,14 +234,14 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, DT, Callable]] out_category = dh.op_out_categories[op] for in_dtype in dh.category_to_dtypes[in_category]: out_dtype = in_dtype if out_category == 'promoted' else xp.bool - for in_stype in dh.dtypes_to_scalars[in_dtype]: + for in_stype in dh.dtype_to_scalars[in_dtype]: yield pytest.param( f'x {symbol} s', in_dtype, in_stype, out_dtype, filters[op], - id=f'{op}({in_dtype}, {in_stype.__name__}) -> {out_dtype}', + id=make_id(op, (in_dtype, in_stype), out_dtype), ) @@ -257,13 +270,13 @@ def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, Callable] continue in_category = dh.op_in_categories[op] for dtype in dh.category_to_dtypes[in_category]: - for in_stype in dh.dtypes_to_scalars[dtype]: + for in_stype in dh.dtype_to_scalars[dtype]: yield pytest.param( f'x {symbol} s', dtype, in_stype, filters[op], - id=f'{op}({dtype}, {in_stype.__name__}) -> {dtype}', + id=make_id(op, (dtype, in_stype), dtype), ) From 8f3332588b3ec6baf4c02b6e00095d7d1fd92682 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 18 Oct 2021 11:20:02 +0100 Subject: [PATCH 36/41] Pass func/op names in params --- array_api_tests/test_type_promotion.py | 61 +++++++++++++------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 2b7d5e0b..825f267f 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -2,7 +2,7 @@ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ from collections import defaultdict -from typing import Iterator, Tuple, Callable, Type, Union +from typing import Iterator, Tuple, Type, Union import pytest from hypothesis import assume, given, reject @@ -54,9 +54,8 @@ def make_id( return f'{func_name}({f_args}) -> {f_out_dtype}' -def gen_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]]: +def gen_func_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT]]: for func_name in elementwise_functions.__all__: - func = getattr(xp, func_name) in_category = dh.func_in_categories[func_name] out_category = dh.func_out_categories[func_name] valid_in_dtypes = dh.category_to_dtypes[in_category] @@ -65,10 +64,9 @@ def gen_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]] for in_dtype in valid_in_dtypes: out_dtype = in_dtype if out_category == 'promoted' else xp.bool yield pytest.param( - func, + func_name, (in_dtype,), out_dtype, - filters[func_name], id=make_id(func_name, (in_dtype,), out_dtype), ) elif ndtypes == 2: @@ -78,21 +76,20 @@ def gen_func_params() -> Iterator[Tuple[Callable, Tuple[DT, ...], DT, Callable]] promoted_dtype if out_category == 'promoted' else xp.bool ) yield pytest.param( - func, + func_name, (in_dtype1, in_dtype2), out_dtype, - filters[func_name], id=make_id(func_name, (in_dtype1, in_dtype2), out_dtype), ) else: raise NotImplementedError() -@pytest.mark.parametrize('func, in_dtypes, out_dtype, x_filter', gen_func_params()) +@pytest.mark.parametrize('func_name, in_dtypes, out_dtype', gen_func_params()) @given(data=st.data()) -def test_func_returns_array_with_correct_dtype( - func, in_dtypes, out_dtype, x_filter, data -): +def test_func_returns_array_with_correct_dtype(func_name, in_dtypes, out_dtype, data): + func = getattr(xp, func_name) + x_filter = filters[func_name] if len(in_dtypes) == 1: x = data.draw( xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' @@ -115,7 +112,7 @@ def test_func_returns_array_with_correct_dtype( assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: +def gen_op_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} for op, symbol in op_to_symbol.items(): if op == '__matmul__': @@ -128,10 +125,10 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: for in_dtype in valid_in_dtypes: out_dtype = in_dtype if out_category == 'promoted' else xp.bool yield pytest.param( + op, f'{symbol}x', (in_dtype,), out_dtype, - filters[op], id=make_id(op, (in_dtype,), out_dtype), ) else: @@ -141,28 +138,29 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: promoted_dtype if out_category == 'promoted' else xp.bool ) yield pytest.param( + op, f'x1 {symbol} x2', (in_dtype1, in_dtype2), out_dtype, - filters[op], id=make_id(op, (in_dtype1, in_dtype2), out_dtype), ) # We generate params for abs seperately as it does not have an associated symbol for in_dtype in dh.category_to_dtypes[dh.op_in_categories['__abs__']]: yield pytest.param( + '__abs__', 'abs(x)', (in_dtype,), in_dtype, - filters['__abs__'], id=make_id('__abs__', (in_dtype,), in_dtype), ) -@pytest.mark.parametrize('expr, in_dtypes, out_dtype, x_filter', gen_op_params()) +@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', gen_op_params()) @given(data=st.data()) def test_operator_returns_array_with_correct_dtype( - expr, in_dtypes, out_dtype, x_filter, data + op, expr, in_dtypes, out_dtype, data ): + x_filter = filters[op] if len(in_dtypes) == 1: x = data.draw( xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x' @@ -184,7 +182,7 @@ def test_operator_returns_array_with_correct_dtype( assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def gen_inplace_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: +def gen_inplace_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: for op, symbol in dh.inplace_op_to_symbol.items(): if op == '__imatmul__': continue @@ -197,20 +195,21 @@ def gen_inplace_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]: and in_dtype2 in valid_in_dtypes ): yield pytest.param( + op, f'x1 {symbol} x2', (in_dtype1, in_dtype2), promoted_dtype, - filters[op], id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype), ) -@pytest.mark.parametrize('expr, in_dtypes, out_dtype, x_filter', gen_inplace_params()) +@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', gen_inplace_params()) @given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) def test_inplace_operator_returns_array_with_correct_dtype( - expr, in_dtypes, out_dtype, x_filter, shapes, data + op, expr, in_dtypes, out_dtype, shapes, data ): assume(len(shapes[0]) >= len(shapes[1])) + x_filter = filters[op] x1 = data.draw( xps.arrays(dtype=in_dtypes[0], shape=shapes[0]).filter(x_filter), label='x1' ) @@ -226,7 +225,7 @@ def test_inplace_operator_returns_array_with_correct_dtype( assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' -def gen_op_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, DT, Callable]]: +def gen_op_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType, DT]]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__': continue @@ -236,22 +235,23 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, DT, Callable]] out_dtype = in_dtype if out_category == 'promoted' else xp.bool for in_stype in dh.dtype_to_scalars[in_dtype]: yield pytest.param( + op, f'x {symbol} s', in_dtype, in_stype, out_dtype, - filters[op], id=make_id(op, (in_dtype, in_stype), out_dtype), ) @pytest.mark.parametrize( - 'expr, in_dtype, in_stype, out_dtype, x_filter', gen_op_scalar_params() + 'op, expr, in_dtype, in_stype, out_dtype', gen_op_scalar_params() ) @given(data=st.data()) def test_binary_operator_promotes_python_scalars( - expr, in_dtype, in_stype, out_dtype, x_filter, data + op, expr, in_dtype, in_stype, out_dtype, data ): + x_filter = filters[op] kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar') x = data.draw( @@ -264,7 +264,7 @@ def test_binary_operator_promotes_python_scalars( assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, Callable]]: +def gen_inplace_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType]]: for op, symbol in dh.inplace_op_to_symbol.items(): if op == '__imatmul__': continue @@ -272,19 +272,18 @@ def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, Callable] for dtype in dh.category_to_dtypes[in_category]: for in_stype in dh.dtype_to_scalars[dtype]: yield pytest.param( + op, f'x {symbol} s', dtype, in_stype, - filters[op], id=make_id(op, (dtype, in_stype), dtype), ) -@pytest.mark.parametrize('expr, dtype, in_stype, x_filter', gen_inplace_scalar_params()) +@pytest.mark.parametrize('op, expr, dtype, in_stype', gen_inplace_scalar_params()) @given(data=st.data()) -def test_inplace_operator_promotes_python_scalars( - expr, dtype, in_stype, x_filter, data -): +def test_inplace_operator_promotes_python_scalars(op, expr, dtype, in_stype, data): + x_filter = filters[op] kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar') x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes).filter(x_filter), label='x') From 7acf87a996538aa543e1970c0cf22de1fdadc7df Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 18 Oct 2021 11:43:44 +0100 Subject: [PATCH 37/41] Shorten test method names --- array_api_tests/test_type_promotion.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 825f267f..b011f4b7 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -87,7 +87,7 @@ def gen_func_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT]]: @pytest.mark.parametrize('func_name, in_dtypes, out_dtype', gen_func_params()) @given(data=st.data()) -def test_func_returns_array_with_correct_dtype(func_name, in_dtypes, out_dtype, data): +def test_func_promotion(func_name, in_dtypes, out_dtype, data): func = getattr(xp, func_name) x_filter = filters[func_name] if len(in_dtypes) == 1: @@ -157,9 +157,7 @@ def gen_op_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: @pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', gen_op_params()) @given(data=st.data()) -def test_operator_returns_array_with_correct_dtype( - op, expr, in_dtypes, out_dtype, data -): +def test_op_promotion(op, expr, in_dtypes, out_dtype, data): x_filter = filters[op] if len(in_dtypes) == 1: x = data.draw( @@ -205,9 +203,7 @@ def gen_inplace_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: @pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', gen_inplace_params()) @given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) -def test_inplace_operator_returns_array_with_correct_dtype( - op, expr, in_dtypes, out_dtype, shapes, data -): +def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data): assume(len(shapes[0]) >= len(shapes[1])) x_filter = filters[op] x1 = data.draw( @@ -248,9 +244,7 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType, DT]]: 'op, expr, in_dtype, in_stype, out_dtype', gen_op_scalar_params() ) @given(data=st.data()) -def test_binary_operator_promotes_python_scalars( - op, expr, in_dtype, in_stype, out_dtype, data -): +def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): x_filter = filters[op] kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar') @@ -282,7 +276,7 @@ def gen_inplace_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType]]: @pytest.mark.parametrize('op, expr, dtype, in_stype', gen_inplace_scalar_params()) @given(data=st.data()) -def test_inplace_operator_promotes_python_scalars(op, expr, dtype, in_stype, data): +def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data): x_filter = filters[op] kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')} s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar') From 0ff687e0136a19d6cce185c28bbabd5a68c73666 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 18 Oct 2021 13:49:13 +0100 Subject: [PATCH 38/41] Replace input dtype categories with dtype objects --- array_api_tests/dtype_helpers.py | 141 +++++++++++-------------- array_api_tests/test_broadcasting.py | 4 +- array_api_tests/test_signatures.py | 11 +- array_api_tests/test_type_promotion.py | 21 ++-- 4 files changed, 79 insertions(+), 98 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 614bb403..c28c841e 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -16,14 +16,11 @@ 'is_int_dtype', 'is_float_dtype', 'dtype_ranges', - 'category_to_dtypes', 'promotion_table', 'dtype_nbits', 'dtype_signed', - 'func_in_categories', + 'func_in_dtypes', 'func_out_categories', - 'op_in_categories', - 'op_out_categories', 'op_to_func', 'binary_op_to_symbol', 'unary_op_to_symbol', @@ -88,16 +85,6 @@ class MinMax(NamedTuple): } -category_to_dtypes = { - 'any': all_dtypes, - 'boolean': (xp.bool,), - 'floating': float_dtypes, - 'integer': all_int_dtypes, - 'integer_or_boolean': (xp.bool,) + uint_dtypes + int_dtypes, - 'numeric': numeric_dtypes, -} - - _numeric_promotions = { # ints (xp.int8, xp.int8): xp.int8, @@ -160,63 +147,63 @@ class MinMax(NamedTuple): } -func_in_categories = { - 'abs': 'numeric', - 'acos': 'floating', - 'acosh': 'floating', - 'add': 'numeric', - 'asin': 'floating', - 'asinh': 'floating', - 'atan': 'floating', - 'atan2': 'floating', - 'atanh': 'floating', - 'bitwise_and': 'integer_or_boolean', - 'bitwise_invert': 'integer_or_boolean', - 'bitwise_left_shift': 'integer', - 'bitwise_or': 'integer_or_boolean', - 'bitwise_right_shift': 'integer', - 'bitwise_xor': 'integer_or_boolean', - 'ceil': 'numeric', - 'cos': 'floating', - 'cosh': 'floating', - 'divide': 'floating', - 'equal': 'any', - 'exp': 'floating', - 'expm1': 'floating', - 'floor': 'numeric', - 'floor_divide': 'numeric', - 'greater': 'numeric', - 'greater_equal': 'numeric', - 'isfinite': 'numeric', - 'isinf': 'numeric', - 'isnan': 'numeric', - 'less': 'numeric', - 'less_equal': 'numeric', - 'log': 'floating', - 'logaddexp': 'floating', - 'log10': 'floating', - 'log1p': 'floating', - 'log2': 'floating', - 'logical_and': 'boolean', - 'logical_not': 'boolean', - 'logical_or': 'boolean', - 'logical_xor': 'boolean', - 'multiply': 'numeric', - 'negative': 'numeric', - 'not_equal': 'any', - 'positive': 'numeric', - 'pow': 'floating', - 'remainder': 'numeric', - 'round': 'numeric', - 'sign': 'numeric', - 'sin': 'floating', - 'sinh': 'floating', - 'sqrt': 'floating', - 'square': 'numeric', - 'subtract': 'numeric', - 'tan': 'floating', - 'tanh': 'floating', - 'trunc': 'numeric', +func_in_dtypes = { + 'abs': numeric_dtypes, + 'acos': float_dtypes, + 'acosh': float_dtypes, + 'add': numeric_dtypes, + 'asin': float_dtypes, + 'asinh': float_dtypes, + 'atan': float_dtypes, + 'atan2': float_dtypes, + 'atanh': float_dtypes, + 'bitwise_and': bool_and_all_int_dtypes, + 'bitwise_invert': bool_and_all_int_dtypes, + 'bitwise_left_shift': all_int_dtypes, + 'bitwise_or': bool_and_all_int_dtypes, + 'bitwise_right_shift': all_int_dtypes, + 'bitwise_xor': bool_and_all_int_dtypes, + 'ceil': numeric_dtypes, + 'cos': float_dtypes, + 'cosh': float_dtypes, + 'divide': float_dtypes, + 'equal': all_dtypes, + 'exp': float_dtypes, + 'expm1': float_dtypes, + 'floor': numeric_dtypes, + 'floor_divide': numeric_dtypes, + 'greater': numeric_dtypes, + 'greater_equal': numeric_dtypes, + 'isfinite': numeric_dtypes, + 'isinf': numeric_dtypes, + 'isnan': numeric_dtypes, + 'less': numeric_dtypes, + 'less_equal': numeric_dtypes, + 'log': float_dtypes, + 'logaddexp': float_dtypes, + 'log10': float_dtypes, + 'log1p': float_dtypes, + 'log2': float_dtypes, + 'logical_and': (xp.bool,), + 'logical_not': (xp.bool,), + 'logical_or': (xp.bool,), + 'logical_xor': (xp.bool,), + 'multiply': numeric_dtypes, + 'negative': numeric_dtypes, + 'not_equal': all_dtypes, + 'positive': numeric_dtypes, + 'pow': float_dtypes, + 'remainder': numeric_dtypes, + 'round': numeric_dtypes, + 'sign': numeric_dtypes, + 'sin': float_dtypes, + 'sinh': float_dtypes, + 'sqrt': float_dtypes, + 'square': numeric_dtypes, + 'subtract': numeric_dtypes, + 'tan': float_dtypes, + 'tanh': float_dtypes, + 'trunc': numeric_dtypes, } @@ -337,18 +324,16 @@ class MinMax(NamedTuple): } -op_in_categories = {} -op_out_categories = {} for op, elwise_func in op_to_func.items(): - op_in_categories[op] = func_in_categories[elwise_func] - op_out_categories[op] = func_out_categories[elwise_func] + func_in_dtypes[op] = func_in_dtypes[elwise_func] + func_out_categories[op] = func_out_categories[elwise_func] inplace_op_to_symbol = {} for op, symbol in binary_op_to_symbol.items(): - if op == '__matmul__' or op_out_categories[op] == 'bool': + if op == '__matmul__' or func_out_categories[op] == 'bool': continue iop = f'__i{op[2:]}' inplace_op_to_symbol[iop] = f'{symbol}=' - op_in_categories[iop] = op_in_categories[op] - op_out_categories[iop] = op_out_categories[op] + func_in_dtypes[iop] = func_in_dtypes[op] + func_out_categories[iop] = func_out_categories[op] diff --git a/array_api_tests/test_broadcasting.py b/array_api_tests/test_broadcasting.py index c8753082..1a1c9c47 100644 --- a/array_api_tests/test_broadcasting.py +++ b/array_api_tests/test_broadcasting.py @@ -10,7 +10,7 @@ from .hypothesis_helpers import shapes, FILTER_UNDEFINED_DTYPES from .pytest_helpers import raises, doesnt_raise, nargs -from .dtype_helpers import func_in_categories, category_to_dtypes +from .dtype_helpers import func_in_dtypes from .function_stubs import elementwise_functions from . import _array_module from ._array_module import ones, _UndefinedStub @@ -115,7 +115,7 @@ def test_broadcasting_hypothesis(func_name, shape1, shape2, data): # Internal consistency checks assert nargs(func_name) == 2 - dtype = data.draw(sampled_from(category_to_dtypes[func_in_categories[func_name]])) + dtype = data.draw(sampled_from(func_in_dtypes[func_name])) if FILTER_UNDEFINED_DTYPES: assume(not isinstance(dtype, _UndefinedStub)) func = getattr(_array_module, func_name) diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index 01968aef..c235255c 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -4,7 +4,7 @@ from ._array_module import mod, mod_name, ones, eye, float64, bool, int64 from .pytest_helpers import raises, doesnt_raise -from .dtype_helpers import func_in_categories, op_to_func +from . import dtype_helpers as dh from . import function_stubs @@ -160,12 +160,13 @@ def test_function_positional_args(name): dtype = None if (name.startswith('__i') and name not in ['__int__', '__invert__', '__index__'] or name.startswith('__r') and name != '__rshift__'): - n = op_to_func[name[:2] + name[3:]] + n = dh.op_to_func[name[:2] + name[3:]] else: - n = op_to_func.get(name, name) - if 'boolean' in func_in_categories.get(n, 'floating'): + n = dh.op_to_func.get(name, name) + in_dtypes = dh.func_in_dtypes.get(n, dh.float_dtypes) + if bool in in_dtypes: dtype = bool - elif 'integer' in func_in_categories.get(n, 'floating'): + elif all(d in in_dtypes for d in dh.all_int_dtypes): dtype = int64 if array_method(name): diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index b011f4b7..805047d6 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -56,9 +56,8 @@ def make_id( def gen_func_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT]]: for func_name in elementwise_functions.__all__: - in_category = dh.func_in_categories[func_name] + valid_in_dtypes = dh.func_in_dtypes[func_name] out_category = dh.func_out_categories[func_name] - valid_in_dtypes = dh.category_to_dtypes[in_category] ndtypes = nargs(func_name) if ndtypes == 1: for in_dtype in valid_in_dtypes: @@ -117,9 +116,8 @@ def gen_op_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: for op, symbol in op_to_symbol.items(): if op == '__matmul__': continue - in_category = dh.op_in_categories[op] - out_category = dh.op_out_categories[op] - valid_in_dtypes = dh.category_to_dtypes[in_category] + valid_in_dtypes = dh.func_in_dtypes[op] + out_category = dh.func_out_categories[op] ndtypes = nargs(op) if ndtypes == 1: for in_dtype in valid_in_dtypes: @@ -145,7 +143,7 @@ def gen_op_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: id=make_id(op, (in_dtype1, in_dtype2), out_dtype), ) # We generate params for abs seperately as it does not have an associated symbol - for in_dtype in dh.category_to_dtypes[dh.op_in_categories['__abs__']]: + for in_dtype in dh.func_in_dtypes['__abs__']: yield pytest.param( '__abs__', 'abs(x)', @@ -184,8 +182,7 @@ def gen_inplace_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: for op, symbol in dh.inplace_op_to_symbol.items(): if op == '__imatmul__': continue - in_category = dh.op_in_categories[op] - valid_in_dtypes = dh.category_to_dtypes[in_category] + valid_in_dtypes = dh.func_in_dtypes[op] for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): if ( in_dtype1 == promoted_dtype @@ -225,9 +222,8 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType, DT]]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__': continue - in_category = dh.op_in_categories[op] - out_category = dh.op_out_categories[op] - for in_dtype in dh.category_to_dtypes[in_category]: + out_category = dh.func_out_categories[op] + for in_dtype in dh.func_in_dtypes[op]: out_dtype = in_dtype if out_category == 'promoted' else xp.bool for in_stype in dh.dtype_to_scalars[in_dtype]: yield pytest.param( @@ -262,8 +258,7 @@ def gen_inplace_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType]]: for op, symbol in dh.inplace_op_to_symbol.items(): if op == '__imatmul__': continue - in_category = dh.op_in_categories[op] - for dtype in dh.category_to_dtypes[in_category]: + for dtype in dh.func_in_dtypes[op]: for in_stype in dh.dtype_to_scalars[dtype]: yield pytest.param( op, From 49279d546d2bc38e7702245dda9d8ab9a24fdf4f Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 18 Oct 2021 13:59:55 +0100 Subject: [PATCH 39/41] Only use `_op_to_func` privately --- array_api_tests/dtype_helpers.py | 5 ++--- array_api_tests/test_signatures.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index c28c841e..7882b845 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -21,7 +21,6 @@ 'dtype_signed', 'func_in_dtypes', 'func_out_categories', - 'op_to_func', 'binary_op_to_symbol', 'unary_op_to_symbol', 'inplace_op_to_symbol', @@ -297,7 +296,7 @@ class MinMax(NamedTuple): } -op_to_func = { +_op_to_func = { '__abs__': 'abs', '__add__': 'add', '__and__': 'bitwise_and', @@ -324,7 +323,7 @@ class MinMax(NamedTuple): } -for op, elwise_func in op_to_func.items(): +for op, elwise_func in _op_to_func.items(): func_in_dtypes[op] = func_in_dtypes[elwise_func] func_out_categories[op] = func_out_categories[elwise_func] diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index c235255c..e8106985 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -160,9 +160,9 @@ def test_function_positional_args(name): dtype = None if (name.startswith('__i') and name not in ['__int__', '__invert__', '__index__'] or name.startswith('__r') and name != '__rshift__'): - n = dh.op_to_func[name[:2] + name[3:]] + n = f'__{name[3:]}' else: - n = dh.op_to_func.get(name, name) + n = name in_dtypes = dh.func_in_dtypes.get(n, dh.float_dtypes) if bool in in_dtypes: dtype = bool From 086cc10833eb0a387d616b462024a0f3dd9a6507 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 18 Oct 2021 14:11:20 +0100 Subject: [PATCH 40/41] Replace `func_out_categories` with `func_returns_bool` --- array_api_tests/dtype_helpers.py | 122 ++++++++++++------------- array_api_tests/test_type_promotion.py | 15 +-- 2 files changed, 66 insertions(+), 71 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 7882b845..b81edb7b 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -20,7 +20,7 @@ 'dtype_nbits', 'dtype_signed', 'func_in_dtypes', - 'func_out_categories', + 'func_returns_bool', 'binary_op_to_symbol', 'unary_op_to_symbol', 'inplace_op_to_symbol', @@ -206,63 +206,63 @@ class MinMax(NamedTuple): } -func_out_categories = { - 'abs': 'promoted', - 'acos': 'promoted', - 'acosh': 'promoted', - 'add': 'promoted', - 'asin': 'promoted', - 'asinh': 'promoted', - 'atan': 'promoted', - 'atan2': 'promoted', - 'atanh': 'promoted', - 'bitwise_and': 'promoted', - 'bitwise_invert': 'promoted', - 'bitwise_left_shift': 'promoted', - 'bitwise_or': 'promoted', - 'bitwise_right_shift': 'promoted', - 'bitwise_xor': 'promoted', - 'ceil': 'promoted', - 'cos': 'promoted', - 'cosh': 'promoted', - 'divide': 'promoted', - 'equal': 'bool', - 'exp': 'promoted', - 'expm1': 'promoted', - 'floor': 'promoted', - 'floor_divide': 'promoted', - 'greater': 'bool', - 'greater_equal': 'bool', - 'isfinite': 'bool', - 'isinf': 'bool', - 'isnan': 'bool', - 'less': 'bool', - 'less_equal': 'bool', - 'log': 'promoted', - 'logaddexp': 'promoted', - 'log10': 'promoted', - 'log1p': 'promoted', - 'log2': 'promoted', - 'logical_and': 'bool', - 'logical_not': 'bool', - 'logical_or': 'bool', - 'logical_xor': 'bool', - 'multiply': 'promoted', - 'negative': 'promoted', - 'not_equal': 'bool', - 'positive': 'promoted', - 'pow': 'promoted', - 'remainder': 'promoted', - 'round': 'promoted', - 'sign': 'promoted', - 'sin': 'promoted', - 'sinh': 'promoted', - 'sqrt': 'promoted', - 'square': 'promoted', - 'subtract': 'promoted', - 'tan': 'promoted', - 'tanh': 'promoted', - 'trunc': 'promoted', +func_returns_bool = { + 'abs': False, + 'acos': False, + 'acosh': False, + 'add': False, + 'asin': False, + 'asinh': False, + 'atan': False, + 'atan2': False, + 'atanh': False, + 'bitwise_and': False, + 'bitwise_invert': False, + 'bitwise_left_shift': False, + 'bitwise_or': False, + 'bitwise_right_shift': False, + 'bitwise_xor': False, + 'ceil': False, + 'cos': False, + 'cosh': False, + 'divide': False, + 'equal': True, + 'exp': False, + 'expm1': False, + 'floor': False, + 'floor_divide': False, + 'greater': True, + 'greater_equal': True, + 'isfinite': True, + 'isinf': True, + 'isnan': True, + 'less': True, + 'less_equal': True, + 'log': False, + 'logaddexp': False, + 'log10': False, + 'log1p': False, + 'log2': False, + 'logical_and': True, + 'logical_not': True, + 'logical_or': True, + 'logical_xor': True, + 'multiply': False, + 'negative': False, + 'not_equal': True, + 'positive': False, + 'pow': False, + 'remainder': False, + 'round': False, + 'sign': False, + 'sin': False, + 'sinh': False, + 'sqrt': False, + 'square': False, + 'subtract': False, + 'tan': False, + 'tanh': False, + 'trunc': False, } @@ -325,14 +325,14 @@ class MinMax(NamedTuple): for op, elwise_func in _op_to_func.items(): func_in_dtypes[op] = func_in_dtypes[elwise_func] - func_out_categories[op] = func_out_categories[elwise_func] + func_returns_bool[op] = func_returns_bool[elwise_func] inplace_op_to_symbol = {} for op, symbol in binary_op_to_symbol.items(): - if op == '__matmul__' or func_out_categories[op] == 'bool': + if op == '__matmul__' or func_returns_bool[op]: continue iop = f'__i{op[2:]}' inplace_op_to_symbol[iop] = f'{symbol}=' func_in_dtypes[iop] = func_in_dtypes[op] - func_out_categories[iop] = func_out_categories[op] + func_returns_bool[iop] = func_returns_bool[op] diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 805047d6..165ba420 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -57,11 +57,10 @@ def make_id( def gen_func_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT]]: for func_name in elementwise_functions.__all__: valid_in_dtypes = dh.func_in_dtypes[func_name] - out_category = dh.func_out_categories[func_name] ndtypes = nargs(func_name) if ndtypes == 1: for in_dtype in valid_in_dtypes: - out_dtype = in_dtype if out_category == 'promoted' else xp.bool + out_dtype = xp.bool if dh.func_returns_bool[func_name] else in_dtype yield pytest.param( func_name, (in_dtype,), @@ -72,7 +71,7 @@ def gen_func_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT]]: for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: out_dtype = ( - promoted_dtype if out_category == 'promoted' else xp.bool + xp.bool if dh.func_returns_bool[func_name] else promoted_dtype ) yield pytest.param( func_name, @@ -117,11 +116,10 @@ def gen_op_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: if op == '__matmul__': continue valid_in_dtypes = dh.func_in_dtypes[op] - out_category = dh.func_out_categories[op] ndtypes = nargs(op) if ndtypes == 1: for in_dtype in valid_in_dtypes: - out_dtype = in_dtype if out_category == 'promoted' else xp.bool + out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype yield pytest.param( op, f'{symbol}x', @@ -132,9 +130,7 @@ def gen_op_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: else: for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: - out_dtype = ( - promoted_dtype if out_category == 'promoted' else xp.bool - ) + out_dtype = xp.bool if dh.func_returns_bool[op] else promoted_dtype yield pytest.param( op, f'x1 {symbol} x2', @@ -222,9 +218,8 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType, DT]]: for op, symbol in dh.binary_op_to_symbol.items(): if op == '__matmul__': continue - out_category = dh.func_out_categories[op] for in_dtype in dh.func_in_dtypes[op]: - out_dtype = in_dtype if out_category == 'promoted' else xp.bool + out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype for in_stype in dh.dtype_to_scalars[in_dtype]: yield pytest.param( op, From b7c0f51f73457cb003d7ac788281b213342793c3 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 18 Oct 2021 15:33:35 +0100 Subject: [PATCH 41/41] Replace param generators with constructed lists --- array_api_tests/test_type_promotion.py | 234 +++++++++++++------------ 1 file changed, 120 insertions(+), 114 deletions(-) diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py index 165ba420..ea4e6eb8 100644 --- a/array_api_tests/test_type_promotion.py +++ b/array_api_tests/test_type_promotion.py @@ -2,7 +2,7 @@ https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html """ from collections import defaultdict -from typing import Iterator, Tuple, Type, Union +from typing import Tuple, Type, Union, List import pytest from hypothesis import assume, given, reject @@ -54,36 +54,38 @@ def make_id( return f'{func_name}({f_args}) -> {f_out_dtype}' -def gen_func_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT]]: - for func_name in elementwise_functions.__all__: - valid_in_dtypes = dh.func_in_dtypes[func_name] - ndtypes = nargs(func_name) - if ndtypes == 1: - for in_dtype in valid_in_dtypes: - out_dtype = xp.bool if dh.func_returns_bool[func_name] else in_dtype - yield pytest.param( +func_params: List[Tuple[str, Tuple[DT, ...], DT]] = [] +for func_name in elementwise_functions.__all__: + valid_in_dtypes = dh.func_in_dtypes[func_name] + ndtypes = nargs(func_name) + if ndtypes == 1: + for in_dtype in valid_in_dtypes: + out_dtype = xp.bool if dh.func_returns_bool[func_name] else in_dtype + p = pytest.param( + func_name, + (in_dtype,), + out_dtype, + id=make_id(func_name, (in_dtype,), out_dtype), + ) + func_params.append(p) + elif ndtypes == 2: + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: + out_dtype = ( + xp.bool if dh.func_returns_bool[func_name] else promoted_dtype + ) + p = pytest.param( func_name, - (in_dtype,), + (in_dtype1, in_dtype2), out_dtype, - id=make_id(func_name, (in_dtype,), out_dtype), + id=make_id(func_name, (in_dtype1, in_dtype2), out_dtype), ) - elif ndtypes == 2: - for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): - if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: - out_dtype = ( - xp.bool if dh.func_returns_bool[func_name] else promoted_dtype - ) - yield pytest.param( - func_name, - (in_dtype1, in_dtype2), - out_dtype, - id=make_id(func_name, (in_dtype1, in_dtype2), out_dtype), - ) - else: - raise NotImplementedError() - - -@pytest.mark.parametrize('func_name, in_dtypes, out_dtype', gen_func_params()) + func_params.append(p) + else: + raise NotImplementedError() + + +@pytest.mark.parametrize('func_name, in_dtypes, out_dtype', func_params) @given(data=st.data()) def test_func_promotion(func_name, in_dtypes, out_dtype, data): func = getattr(xp, func_name) @@ -110,46 +112,49 @@ def test_func_promotion(func_name, in_dtypes, out_dtype, data): assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def gen_op_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: - op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} - for op, symbol in op_to_symbol.items(): - if op == '__matmul__': - continue - valid_in_dtypes = dh.func_in_dtypes[op] - ndtypes = nargs(op) - if ndtypes == 1: - for in_dtype in valid_in_dtypes: - out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype - yield pytest.param( +op_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = [] +op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} +for op, symbol in op_to_symbol.items(): + if op == '__matmul__': + continue + valid_in_dtypes = dh.func_in_dtypes[op] + ndtypes = nargs(op) + if ndtypes == 1: + for in_dtype in valid_in_dtypes: + out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype + p = pytest.param( + op, + f'{symbol}x', + (in_dtype,), + out_dtype, + id=make_id(op, (in_dtype,), out_dtype), + ) + op_params.append(p) + else: + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: + out_dtype = xp.bool if dh.func_returns_bool[op] else promoted_dtype + p = pytest.param( op, - f'{symbol}x', - (in_dtype,), + f'x1 {symbol} x2', + (in_dtype1, in_dtype2), out_dtype, - id=make_id(op, (in_dtype,), out_dtype), + id=make_id(op, (in_dtype1, in_dtype2), out_dtype), ) - else: - for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): - if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: - out_dtype = xp.bool if dh.func_returns_bool[op] else promoted_dtype - yield pytest.param( - op, - f'x1 {symbol} x2', - (in_dtype1, in_dtype2), - out_dtype, - id=make_id(op, (in_dtype1, in_dtype2), out_dtype), - ) - # We generate params for abs seperately as it does not have an associated symbol - for in_dtype in dh.func_in_dtypes['__abs__']: - yield pytest.param( - '__abs__', - 'abs(x)', - (in_dtype,), - in_dtype, - id=make_id('__abs__', (in_dtype,), in_dtype), - ) + op_params.append(p) +# We generate params for abs seperately as it does not have an associated symbol +for in_dtype in dh.func_in_dtypes['__abs__']: + p = pytest.param( + '__abs__', + 'abs(x)', + (in_dtype,), + in_dtype, + id=make_id('__abs__', (in_dtype,), in_dtype), + ) + op_params.append(p) -@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', gen_op_params()) +@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', op_params) @given(data=st.data()) def test_op_promotion(op, expr, in_dtypes, out_dtype, data): x_filter = filters[op] @@ -174,27 +179,28 @@ def test_op_promotion(op, expr, in_dtypes, out_dtype, data): assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def gen_inplace_params() -> Iterator[Tuple[str, str, Tuple[DT, ...], DT]]: - for op, symbol in dh.inplace_op_to_symbol.items(): - if op == '__imatmul__': - continue - valid_in_dtypes = dh.func_in_dtypes[op] - for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): - if ( - in_dtype1 == promoted_dtype - and in_dtype1 in valid_in_dtypes - and in_dtype2 in valid_in_dtypes - ): - yield pytest.param( - op, - f'x1 {symbol} x2', - (in_dtype1, in_dtype2), - promoted_dtype, - id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype), - ) +inplace_params: List[Tuple[str, str, Tuple[DT, ...], DT]] = [] +for op, symbol in dh.inplace_op_to_symbol.items(): + if op == '__imatmul__': + continue + valid_in_dtypes = dh.func_in_dtypes[op] + for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): + if ( + in_dtype1 == promoted_dtype + and in_dtype1 in valid_in_dtypes + and in_dtype2 in valid_in_dtypes + ): + p = pytest.param( + op, + f'x1 {symbol} x2', + (in_dtype1, in_dtype2), + promoted_dtype, + id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype), + ) + inplace_params.append(p) -@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', gen_inplace_params()) +@pytest.mark.parametrize('op, expr, in_dtypes, out_dtype', inplace_params) @given(shapes=hh.mutually_broadcastable_shapes(2), data=st.data()) def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data): assume(len(shapes[0]) >= len(shapes[1])) @@ -214,26 +220,25 @@ def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shapes, data): assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}' -def gen_op_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType, DT]]: - for op, symbol in dh.binary_op_to_symbol.items(): - if op == '__matmul__': - continue - for in_dtype in dh.func_in_dtypes[op]: - out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype - for in_stype in dh.dtype_to_scalars[in_dtype]: - yield pytest.param( - op, - f'x {symbol} s', - in_dtype, - in_stype, - out_dtype, - id=make_id(op, (in_dtype, in_stype), out_dtype), - ) +op_scalar_params: List[Tuple[str, str, DT, ScalarType, DT]] = [] +for op, symbol in dh.binary_op_to_symbol.items(): + if op == '__matmul__': + continue + for in_dtype in dh.func_in_dtypes[op]: + out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype + for in_stype in dh.dtype_to_scalars[in_dtype]: + p = pytest.param( + op, + f'x {symbol} s', + in_dtype, + in_stype, + out_dtype, + id=make_id(op, (in_dtype, in_stype), out_dtype), + ) + op_scalar_params.append(p) -@pytest.mark.parametrize( - 'op, expr, in_dtype, in_stype, out_dtype', gen_op_scalar_params() -) +@pytest.mark.parametrize('op, expr, in_dtype, in_stype, out_dtype', op_scalar_params) @given(data=st.data()) def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): x_filter = filters[op] @@ -249,22 +254,23 @@ def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}' -def gen_inplace_scalar_params() -> Iterator[Tuple[str, str, DT, ScalarType]]: - for op, symbol in dh.inplace_op_to_symbol.items(): - if op == '__imatmul__': - continue - for dtype in dh.func_in_dtypes[op]: - for in_stype in dh.dtype_to_scalars[dtype]: - yield pytest.param( - op, - f'x {symbol} s', - dtype, - in_stype, - id=make_id(op, (dtype, in_stype), dtype), - ) +inplace_scalar_params: List[Tuple[str, str, DT, ScalarType]] = [] +for op, symbol in dh.inplace_op_to_symbol.items(): + if op == '__imatmul__': + continue + for dtype in dh.func_in_dtypes[op]: + for in_stype in dh.dtype_to_scalars[dtype]: + p = pytest.param( + op, + f'x {symbol} s', + dtype, + in_stype, + id=make_id(op, (dtype, in_stype), dtype), + ) + inplace_scalar_params.append(p) -@pytest.mark.parametrize('op, expr, dtype, in_stype', gen_inplace_scalar_params()) +@pytest.mark.parametrize('op, expr, dtype, in_stype', inplace_scalar_params) @given(data=st.data()) def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data): x_filter = filters[op]