Skip to content

Commit abafee1

Browse files
committed
More flexible size assertions for int arrays in test_arange
1 parent 8a7e873 commit abafee1

File tree

1 file changed

+24
-17
lines changed

1 file changed

+24
-17
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -217,25 +217,32 @@ def test_arange(dtype, data):
217217
assert out.dtype == dtype
218218
assert out.ndim == 1, f"{out.ndim=}, but should be 1 [linspace()]"
219219
f_func = f"[linspace({start=}, {stop=}, {step=})]"
220+
# We check size is roughly as expected to avoid edge cases e.g.
221+
#
222+
# >>> xp.arange(2, step=0.333333333333333)
223+
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66, 2.0]
224+
# >>> xp.arange(2, step=0.3333333333333333)
225+
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
226+
#
227+
# >>> start, stop, step = 0, 108086391056891901, 1080863910568919
228+
# >>> x = xp.arange(start, stop, step, dtype=xp.uint64)
229+
# >>> x.size
230+
# 100
231+
# >>> r = range(start, stop, step)
232+
# >>> len(r)
233+
# 101
234+
#
235+
min_size = math.floor(size * 0.9)
236+
max_size = max(math.ceil(size * 1.1), 1)
237+
assert (
238+
min_size <= out.size <= max_size
239+
), f"{out.size=}, but should be roughly {size} {f_func}"
220240
if dh.is_int_dtype(_dtype):
221-
assert out.size == size, f"{out.size=}, but should be {size} {f_func}"
222-
else:
223-
# We check size is roughly as expected to avoid edge cases e.g.
224-
#
225-
# >>> xp.arange(2, step=0.333333333333333)
226-
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66, 2.0]
227-
# >>> xp.arange(2, step=0.3333333333333333)
228-
# [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
229-
#
230-
min_size = math.floor(size * 0.9)
231-
max_size = max(math.ceil(size * 1.1), 1)
232-
assert (
233-
min_size <= out.size <= max_size
234-
), f"{out.size=}, but should be roughly {size} {f_func}"
235-
assume(out.size == size)
236-
if dh.is_int_dtype(_dtype):
237-
ah.assert_exactly_equal(out, ah.asarray(list(r), dtype=_dtype))
241+
elements = list(r)
242+
assume(out.size == len(elements))
243+
ah.assert_exactly_equal(out, ah.asarray(elements, dtype=_dtype))
238244
else:
245+
assume(out.size == size)
239246
if out.size > 0:
240247
assert ah.equal(
241248
out[0], ah.asarray(_start, dtype=out.dtype)

0 commit comments

Comments
 (0)