Skip to content

Commit 788b5ec

Browse files
committed
hh.reject_overflow()
1 parent 1a73804 commit 788b5ec

File tree

5 files changed

+47
-21
lines changed

5 files changed

+47
-21
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import re
12
import itertools
3+
from contextlib import contextmanager
24
from functools import reduce
35
from math import sqrt
46
from operator import mul
@@ -477,3 +479,14 @@ def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]:
477479
axes_strats.append(integers(-ndim, ndim - 1))
478480
axes_strats.append(xps.valid_tuple_axes(ndim))
479481
return one_of(axes_strats)
482+
483+
484+
@contextmanager
485+
def reject_overflow():
486+
try:
487+
yield
488+
except Exception as e:
489+
if isinstance(e, OverflowError) or re.search("[Oo]verflow", str(e)):
490+
reject()
491+
else:
492+
raise e

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from math import prod
2+
from typing import Type
23

34
import pytest
45
from hypothesis import given, settings
56
from hypothesis import strategies as st
7+
from hypothesis.errors import Unsatisfiable
68

79
from .. import _array_module as xp
810
from .. import array_helpers as ah
@@ -144,3 +146,27 @@ def test_symmetric_matrices(m, dtype, finite):
144146
def test_positive_definite_matrices(m, dtype):
145147
assert m.dtype == dtype
146148
# TODO: Test that it actually is positive definite
149+
150+
151+
def make_raising_func(cls: Type[Exception], msg: str):
152+
def raises():
153+
raise cls(msg)
154+
155+
return raises
156+
157+
@pytest.mark.parametrize(
158+
"func",
159+
[
160+
make_raising_func(OverflowError, "foo"),
161+
make_raising_func(RuntimeError, "Overflow when unpacking long"),
162+
make_raising_func(Exception, "Got an overflow"),
163+
]
164+
)
165+
def test_reject_overflow(func):
166+
@given(data=st.data())
167+
def test_case(data):
168+
with hh.reject_overflow():
169+
func()
170+
171+
with pytest.raises(Unsatisfiable):
172+
test_case()

array_api_tests/meta/test_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from hypothesis import given, reject
2+
from hypothesis import given
33
from hypothesis import strategies as st
44

55
from .. import _array_module as xp
@@ -105,10 +105,8 @@ def test_fmt_idx(idx, expected):
105105

106106
@given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes())
107107
def test_int_to_dtype(x, dtype):
108-
try:
108+
with hh.reject_overflow():
109109
d = xp.asarray(x, dtype=dtype)
110-
except OverflowError:
111-
reject()
112110
assert mock_int_dtype(x, dtype) == d
113111

114112

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union
99

1010
import pytest
11-
from hypothesis import assume, given, reject
11+
from hypothesis import assume, given
1212
from hypothesis import strategies as st
1313

1414
from . import _array_module as xp, api_version
@@ -740,10 +740,8 @@ def test_add(ctx, data):
740740
left = data.draw(ctx.left_strat, label=ctx.left_sym)
741741
right = data.draw(ctx.right_strat, label=ctx.right_sym)
742742

743-
try:
743+
with hh.reject_overflow():
744744
res = ctx.func(left, right)
745-
except OverflowError:
746-
reject()
747745

748746
binary_param_assert_dtype(ctx, left, right, res)
749747
binary_param_assert_shape(ctx, left, right, res)
@@ -1327,10 +1325,8 @@ def test_pow(ctx, data):
13271325
if dh.is_int_dtype(right.dtype):
13281326
assume(xp.all(right >= 0))
13291327

1330-
try:
1328+
with hh.reject_overflow():
13311329
res = ctx.func(left, right)
1332-
except OverflowError:
1333-
reject()
13341330

13351331
binary_param_assert_dtype(ctx, left, right, res)
13361332
binary_param_assert_shape(ctx, left, right, res)
@@ -1425,10 +1421,8 @@ def test_subtract(ctx, data):
14251421
left = data.draw(ctx.left_strat, label=ctx.left_sym)
14261422
right = data.draw(ctx.right_strat, label=ctx.right_sym)
14271423

1428-
try:
1424+
with hh.reject_overflow():
14291425
res = ctx.func(left, right)
1430-
except OverflowError:
1431-
reject()
14321426

14331427
binary_param_assert_dtype(ctx, left, right, res)
14341428
binary_param_assert_shape(ctx, left, right, res)

array_api_tests/test_statistical_functions.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytest
66
from hypothesis import assume, given
77
from hypothesis import strategies as st
8-
from hypothesis.control import reject
98

109
from . import _array_module as xp
1110
from . import dtype_helpers as dh
@@ -127,10 +126,8 @@ def test_prod(x, data):
127126
)
128127
keepdims = kw.get("keepdims", False)
129128

130-
try:
129+
with hh.reject_overflow():
131130
out = xp.prod(x, **kw)
132-
except OverflowError:
133-
reject()
134131

135132
dtype = kw.get("dtype", None)
136133
if dtype is None:
@@ -234,10 +231,8 @@ def test_sum(x, data):
234231
)
235232
keepdims = kw.get("keepdims", False)
236233

237-
try:
234+
with hh.reject_overflow():
238235
out = xp.sum(x, **kw)
239-
except OverflowError:
240-
reject()
241236

242237
dtype = kw.get("dtype", None)
243238
if dtype is None:

0 commit comments

Comments
 (0)