From ba6995e3f262b64a0f7199f8d709a7847ba9dc1b Mon Sep 17 00:00:00 2001 From: Matthew Date: Thu, 21 Oct 2021 12:50:36 +0100 Subject: [PATCH] Bound `test_full` fill values to default dtypes --- array_api_tests/dtype_helpers.py | 16 ++++++++++++++++ array_api_tests/test_creation_functions.py | 14 ++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index b81edb7b..a60abef1 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,6 +1,8 @@ +from warnings import warn from typing import NamedTuple from . import _array_module as xp +from ._array_module import _UndefinedStub __all__ = [ @@ -16,6 +18,8 @@ 'is_int_dtype', 'is_float_dtype', 'dtype_ranges', + 'default_int', + 'default_float', 'promotion_table', 'dtype_nbits', 'dtype_signed', @@ -84,6 +88,18 @@ class MinMax(NamedTuple): } +if isinstance(xp.asarray, _UndefinedStub): + default_int = xp.int32 + default_float = xp.float32 + warn( + 'array module does not have attribute asarray. ' + 'default int is assumed int32, default float is assumed float32' + ) +else: + default_int = xp.asarray(int()).dtype + default_float = xp.asarray(float()).dtype + + _numeric_promotions = { # ints (xp.int8, xp.int8): xp.int8, diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 302bef49..58575a8a 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -12,7 +12,7 @@ from . import xps from hypothesis import assume, given -from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite +from hypothesis.strategies import integers, floats, one_of, none, booleans, just, shared, composite, SearchStrategy int_range = integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE) @@ -128,10 +128,20 @@ def test_eye(n_rows, n_cols, k, dtype): assert a[i, j] == 0, "eye() did not produce a 0 off the diagonal" +default_unsafe_dtypes = [xp.uint64] +if dh.default_int == xp.int32: + default_unsafe_dtypes.extend([xp.uint32, xp.int64]) +if dh.default_float == xp.float32: + default_unsafe_dtypes.append(xp.float64) +default_safe_scalar_dtypes: SearchStrategy = xps.scalar_dtypes().filter( + lambda d: d not in default_unsafe_dtypes +) + + @composite def full_fill_values(draw): kw = draw(shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_kw")) - dtype = kw.get("dtype", None) or draw(xps.scalar_dtypes()) + dtype = kw.get("dtype", None) or draw(default_safe_scalar_dtypes) return draw(xps.from_dtype(dtype))