Skip to content

Commit e40cac8

Browse files
committed
Reject on OverflowError
1 parent 198df7e commit e40cac8

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Iterator, TypeVar, Tuple, Callable, Type, Union
66

77
import pytest
8-
from hypothesis import assume, given
8+
from hypothesis import assume, given, reject
99
from hypothesis import strategies as st
1010

1111
from . import _array_module as xp
@@ -81,7 +81,7 @@ def test_func_returns_array_with_correct_dtype(
8181
x = data.draw(
8282
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x'
8383
)
84-
out = func(x)
84+
arrays = [x]
8585
else:
8686
arrays = []
8787
shapes = data.draw(
@@ -92,7 +92,10 @@ def test_func_returns_array_with_correct_dtype(
9292
xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}'
9393
)
9494
arrays.append(x)
95+
try:
9596
out = func(*arrays)
97+
except OverflowError:
98+
reject()
9699
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
97100

98101

@@ -144,21 +147,23 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]:
144147
def test_operator_returns_array_with_correct_dtype(
145148
expr, in_dtypes, out_dtype, x_filter, data
146149
):
150+
locals_ = {}
147151
if len(in_dtypes) == 1:
148-
x = data.draw(
152+
locals_['x'] = data.draw(
149153
xps.arrays(dtype=in_dtypes[0], shape=hh.shapes).filter(x_filter), label='x'
150154
)
151-
out = eval(expr, {'x': x})
152155
else:
153-
locals_ = {}
154156
shapes = data.draw(
155157
hh.mutually_broadcastable_shapes(len(in_dtypes)), label='shapes'
156158
)
157159
for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1):
158160
locals_[f'x{i}'] = data.draw(
159161
xps.arrays(dtype=dtype, shape=shape).filter(x_filter), label=f'x{i}'
160162
)
163+
try:
161164
out = eval(expr, locals_)
165+
except OverflowError:
166+
reject()
162167
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
163168

164169

@@ -197,7 +202,10 @@ def test_inplace_operator_returns_array_with_correct_dtype(
197202
xps.arrays(dtype=in_dtypes[1], shape=shapes[1]).filter(x_filter), label='x2'
198203
)
199204
locals_ = {'x1': x1, 'x2': x2}
200-
exec(expr, locals_)
205+
try:
206+
exec(expr, locals_)
207+
except OverflowError:
208+
reject()
201209
x1 = locals_['x1']
202210
assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}'
203211

@@ -239,7 +247,7 @@ def test_binary_operator_promotes_python_scalars(
239247
try:
240248
out = eval(expr, {'x': x, 's': s})
241249
except OverflowError:
242-
assume(False)
250+
reject()
243251
assert out.dtype == out_dtype, f'{out.dtype=!s}, but should be {out_dtype}'
244252

245253

@@ -271,7 +279,7 @@ def test_inplace_operator_promotes_python_scalars(
271279
try:
272280
exec(expr, locals_)
273281
except OverflowError:
274-
assume(False)
282+
reject()
275283
x = locals_['x']
276284
assert x.dtype == dtype, f'{x.dtype=!s}, but should be {dtype}'
277285

0 commit comments

Comments
 (0)