Skip to content

Commit b970435

Browse files
committed
Try to ignore overflow scenarios in prod and sum tests
1 parent 3f28c60 commit b970435

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

array_api_tests/test_statistical_functions.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from hypothesis import assume, given
66
from hypothesis import strategies as st
7+
from hypothesis.control import reject
78

89
from . import _array_module as xp
910
from . import array_helpers as ah
@@ -202,7 +203,10 @@ def test_prod(x, data):
202203
label="kw",
203204
)
204205

205-
out = xp.prod(x, **kw)
206+
try:
207+
out = xp.prod(x, **kw)
208+
except OverflowError:
209+
reject()
206210

207211
dtype = kw.get("dtype", None)
208212
if dtype is None:
@@ -232,7 +236,7 @@ def test_prod(x, data):
232236
scalar_type = dh.get_scalar_type(out.dtype)
233237
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
234238
prod = scalar_type(out[out_idx])
235-
assume(not math.isinf(prod))
239+
assume(math.isfinite(prod))
236240
elements = []
237241
for idx in indices:
238242
s = scalar_type(x[idx])
@@ -297,7 +301,10 @@ def test_sum(x, data):
297301
label="kw",
298302
)
299303

300-
out = xp.sum(x, **kw)
304+
try:
305+
out = xp.sum(x, **kw)
306+
except OverflowError:
307+
reject()
301308

302309
dtype = kw.get("dtype", None)
303310
if dtype is None:
@@ -327,7 +334,7 @@ def test_sum(x, data):
327334
scalar_type = dh.get_scalar_type(out.dtype)
328335
for indices, out_idx in zip(axes_ndindex(x.shape, _axes), ah.ndindex(out.shape)):
329336
sum_ = scalar_type(out[out_idx])
330-
assume(not math.isinf(sum_))
337+
assume(math.isfinite(sum_))
331338
elements = []
332339
for idx in indices:
333340
s = scalar_type(x[idx])

0 commit comments

Comments
 (0)