Skip to content

hh.reject_overflow() #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import re
import itertools
from contextlib import contextmanager
from functools import reduce
from math import sqrt
from operator import mul
Expand Down Expand Up @@ -477,3 +479,14 @@ def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]:
axes_strats.append(integers(-ndim, ndim - 1))
axes_strats.append(xps.valid_tuple_axes(ndim))
return one_of(axes_strats)


@contextmanager
def reject_overflow():
try:
yield
except Exception as e:
if isinstance(e, OverflowError) or re.search("[Oo]verflow", str(e)):
reject()
else:
raise e
26 changes: 26 additions & 0 deletions array_api_tests/meta/test_hypothesis_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from math import prod
from typing import Type

import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from hypothesis.errors import Unsatisfiable

from .. import _array_module as xp
from .. import array_helpers as ah
Expand Down Expand Up @@ -144,3 +146,27 @@ def test_symmetric_matrices(m, dtype, finite):
def test_positive_definite_matrices(m, dtype):
assert m.dtype == dtype
# TODO: Test that it actually is positive definite


def make_raising_func(cls: Type[Exception], msg: str):
def raises():
raise cls(msg)

return raises

@pytest.mark.parametrize(
"func",
[
make_raising_func(OverflowError, "foo"),
make_raising_func(RuntimeError, "Overflow when unpacking long"),
make_raising_func(Exception, "Got an overflow"),
]
)
def test_reject_overflow(func):
@given(data=st.data())
def test_case(data):
with hh.reject_overflow():
func()

with pytest.raises(Unsatisfiable):
test_case()
6 changes: 2 additions & 4 deletions array_api_tests/meta/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from hypothesis import given, reject
from hypothesis import given
from hypothesis import strategies as st

from .. import _array_module as xp
Expand Down Expand Up @@ -105,10 +105,8 @@ def test_fmt_idx(idx, expected):

@given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes())
def test_int_to_dtype(x, dtype):
try:
with hh.reject_overflow():
d = xp.asarray(x, dtype=dtype)
except OverflowError:
reject()
assert mock_int_dtype(x, dtype) == d


Expand Down
14 changes: 4 additions & 10 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union

import pytest
from hypothesis import assume, given, reject
from hypothesis import assume, given
from hypothesis import strategies as st

from . import _array_module as xp, api_version
Expand Down Expand Up @@ -740,10 +740,8 @@ def test_add(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
right = data.draw(ctx.right_strat, label=ctx.right_sym)

try:
with hh.reject_overflow():
res = ctx.func(left, right)
except OverflowError:
reject()

binary_param_assert_dtype(ctx, left, right, res)
binary_param_assert_shape(ctx, left, right, res)
Expand Down Expand Up @@ -1327,10 +1325,8 @@ def test_pow(ctx, data):
if dh.is_int_dtype(right.dtype):
assume(xp.all(right >= 0))

try:
with hh.reject_overflow():
res = ctx.func(left, right)
except OverflowError:
reject()

binary_param_assert_dtype(ctx, left, right, res)
binary_param_assert_shape(ctx, left, right, res)
Expand Down Expand Up @@ -1425,10 +1421,8 @@ def test_subtract(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
right = data.draw(ctx.right_strat, label=ctx.right_sym)

try:
with hh.reject_overflow():
res = ctx.func(left, right)
except OverflowError:
reject()

binary_param_assert_dtype(ctx, left, right, res)
binary_param_assert_shape(ctx, left, right, res)
Expand Down
9 changes: 2 additions & 7 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pytest
from hypothesis import assume, given
from hypothesis import strategies as st
from hypothesis.control import reject

from . import _array_module as xp
from . import dtype_helpers as dh
Expand Down Expand Up @@ -127,10 +126,8 @@ def test_prod(x, data):
)
keepdims = kw.get("keepdims", False)

try:
with hh.reject_overflow():
out = xp.prod(x, **kw)
except OverflowError:
reject()

dtype = kw.get("dtype", None)
if dtype is None:
Expand Down Expand Up @@ -234,10 +231,8 @@ def test_sum(x, data):
)
keepdims = kw.get("keepdims", False)

try:
with hh.reject_overflow():
out = xp.sum(x, **kw)
except OverflowError:
reject()

dtype = kw.get("dtype", None)
if dtype is None:
Expand Down