Skip to content

Commit eab08d9

Browse files
committed
Generate NaNs and infs for scalar promotion tests
1 parent 06fc784 commit eab08d9

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

array_api_tests/test_type_promotion.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ def test_inplace_operator_returns_array_with_correct_dtype(
209209
assert x1.dtype == out_dtype, f'{x1.dtype=!s}, but should be {out_dtype}'
210210

211211

212-
finite_kw = {'allow_nan': False, 'allow_infinity': False}
213212
ScalarType = Union[Type[bool], Type[int], Type[float]]
214213

215214

@@ -239,7 +238,8 @@ def gen_op_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, DT, Callable]]
239238
def test_binary_operator_promotes_python_scalars(
240239
expr, in_dtype, in_stype, out_dtype, x_filter, data
241240
):
242-
s = data.draw(xps.from_dtype(in_dtype, **finite_kw).map(in_stype), label='scalar')
241+
kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')}
242+
s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label='scalar')
243243
x = data.draw(
244244
xps.arrays(dtype=in_dtype, shape=hh.shapes).filter(x_filter), label='x'
245245
)
@@ -271,7 +271,8 @@ def gen_inplace_scalar_params() -> Iterator[Tuple[str, DT, ScalarType, Callable]
271271
def test_inplace_operator_promotes_python_scalars(
272272
expr, dtype, in_stype, x_filter, data
273273
):
274-
s = data.draw(xps.from_dtype(dtype, **finite_kw).map(in_stype), label='scalar')
274+
kw = {k: in_stype is float for k in ('allow_nan', 'allow_infinity')}
275+
s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label='scalar')
275276
x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes).filter(x_filter), label='x')
276277
locals_ = {'x': x, 's': s}
277278
try:

0 commit comments

Comments
 (0)