Skip to content

Commit 4d22300

Browse files
committed
Prevent hypothesis testing boundary numbers
1 parent 73c47d8 commit 4d22300

15 files changed

+203
-149
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import re
22
import itertools
33
from contextlib import contextmanager
4-
from functools import reduce
5-
from math import sqrt
4+
from functools import reduce, wraps
5+
import math
66
from operator import mul
7-
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
7+
import struct
8+
from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
89

910
from hypothesis import assume, reject
1011
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
1112
integers, just, lists, none, one_of,
12-
sampled_from, shared)
13+
sampled_from, shared, builds)
1314

1415
from . import _array_module as xp, api_version
1516
from . import dtype_helpers as dh
@@ -20,7 +21,60 @@
2021
from ._array_module import broadcast_to, eye, float32, float64, full
2122
from .stubs import category_to_funcs
2223
from .pytest_helpers import nargs
23-
from .typing import Array, DataType, Shape
24+
from .typing import Array, DataType, Scalar, Shape
25+
26+
27+
def _float32ify(n: Union[int, float]) -> float:
28+
n = float(n)
29+
return struct.unpack("!f", struct.pack("!f", n))[0]
30+
31+
32+
@wraps(xps.from_dtype)
33+
def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
34+
"""xps.from_dtype() without the crazy large numbers."""
35+
if dtype == xp.bool:
36+
return xps.from_dtype(dtype, **kwargs)
37+
38+
if dtype in dh.complex_dtypes:
39+
component_dtype = dh.dtype_components[dtype]
40+
else:
41+
component_dtype = dtype
42+
43+
min_, max_ = dh.dtype_ranges[component_dtype]
44+
45+
if "min_value" not in kwargs.keys() and min_ != 0:
46+
assert min_ < 0 # sanity check
47+
min_value = -1 * math.floor(math.sqrt(abs(min_)))
48+
if component_dtype == xp.float32:
49+
min_value = _float32ify(min_value)
50+
kwargs["min_value"] = min_value
51+
if "max_value" not in kwargs.keys():
52+
assert max_ > 0 # sanity check
53+
max_value = math.floor(math.sqrt(max_))
54+
if component_dtype == xp.float32:
55+
max_value = _float32ify(max_value)
56+
kwargs["max_value"] = max_value
57+
58+
if dtype in dh.complex_dtypes:
59+
component_strat = xps.from_dtype(dh.dtype_components[dtype], **kwargs)
60+
return builds(complex, component_strat, component_strat)
61+
else:
62+
return xps.from_dtype(dtype, **kwargs)
63+
64+
65+
@wraps(xps.arrays)
66+
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
67+
"""xps.arrays() without the crazy large numbers."""
68+
if isinstance(dtype, SearchStrategy):
69+
return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
70+
71+
if elements is None:
72+
elements = from_dtype(dtype)
73+
elif isinstance(elements, Mapping):
74+
elements = from_dtype(dtype, **elements)
75+
76+
return xps.arrays(dtype, *args, elements=elements, **kwargs)
77+
2478

2579
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
2680
_sorted_dtypes = [d for category in _dtype_categories for d in category]
@@ -145,7 +199,7 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
145199
# Limit the total size of an array shape
146200
MAX_ARRAY_SIZE = 10000
147201
# Size to use for 2-dim arrays
148-
SQRT_MAX_ARRAY_SIZE = int(sqrt(MAX_ARRAY_SIZE))
202+
SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
149203

150204
# np.prod and others have overflow and math.prod is Python 3.8+ only
151205
def prod(seq):
@@ -181,7 +235,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
181235

182236
@composite
183237
def finite_matrices(draw, shape=matrix_shapes()):
184-
return draw(xps.arrays(dtype=xps.floating_dtypes(),
238+
return draw(arrays(dtype=xps.floating_dtypes(),
185239
shape=shape,
186240
elements=dict(allow_nan=False,
187241
allow_infinity=False)))
@@ -190,7 +244,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
190244
# Should we set a max_value here?
191245
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
192246
rtols = one_of(floats(**_rtol_float_kw),
193-
xps.arrays(dtype=xps.floating_dtypes(),
247+
arrays(dtype=xps.floating_dtypes(),
194248
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
195249
elements=_rtol_float_kw))
196250

@@ -233,7 +287,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
233287
if not isinstance(finite, bool):
234288
finite = draw(finite)
235289
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
236-
a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements))
290+
a = draw(arrays(dtype=dtype, shape=shape, elements=elements))
237291
upper = xp.triu(a)
238292
lower = xp.triu(a, k=1).mT
239293
return upper + lower
@@ -256,7 +310,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
256310
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
257311
stack_shape = draw(stack_shapes)
258312
shape = stack_shape + (n, n)
259-
d = draw(xps.arrays(dtypes, shape=n*prod(stack_shape),
313+
d = draw(arrays(dtypes, shape=n*prod(stack_shape),
260314
elements=dict(allow_nan=False, allow_infinity=False)))
261315
# Functions that require invertible matrices may do anything when it is
262316
# singular, including raising an exception, so we make sure the diagonals
@@ -282,7 +336,7 @@ def two_broadcastable_shapes(draw):
282336
sizes = integers(0, MAX_ARRAY_SIZE)
283337
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
284338

285-
numeric_arrays = xps.arrays(
339+
numeric_arrays = arrays(
286340
dtype=shared(xps.floating_dtypes(), key='dtypes'),
287341
shape=shared(xps.array_shapes(), key='shapes'),
288342
)
@@ -407,11 +461,11 @@ def two_mutual_arrays(
407461
assert len(dtypes) > 0 # sanity check
408462
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
409463
mutual_shapes = shared(two_shapes)
410-
arrays1 = xps.arrays(
464+
arrays1 = arrays(
411465
dtype=mutual_dtypes.map(lambda pair: pair[0]),
412466
shape=mutual_shapes.map(lambda pair: pair[0]),
413467
)
414-
arrays2 = xps.arrays(
468+
arrays2 = arrays(
415469
dtype=mutual_dtypes.map(lambda pair: pair[1]),
416470
shape=mutual_shapes.map(lambda pair: pair[1]),
417471
)

array_api_tests/test_array_object.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def scalar_objects(
2424
) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
2525
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
2626
size = math.prod(shape)
27-
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
27+
return st.lists(hh.from_dtype(dtype), min_size=size, max_size=size).map(
2828
lambda l: sh.reshape(l, shape)
2929
)
3030

@@ -123,10 +123,10 @@ def test_setitem(shape, dtypes, data):
123123
key = data.draw(xps.indices(shape=shape), label="key")
124124
_key = normalise_key(key, shape)
125125
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
126-
value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape)
126+
value_strat = hh.arrays(dtype=dtypes.result_dtype, shape=out_shape)
127127
if out_shape == ():
128128
# We can pass scalars if we're only indexing one element
129-
value_strat |= xps.from_dtype(dtypes.result_dtype)
129+
value_strat |= hh.from_dtype(dtypes.result_dtype)
130130
value = data.draw(value_strat, label="value")
131131

132132
res = xp.asarray(x, copy=True)
@@ -157,15 +157,15 @@ def test_setitem(shape, dtypes, data):
157157
@pytest.mark.data_dependent_shapes
158158
@given(hh.shapes(), st.data())
159159
def test_getitem_masking(shape, data):
160-
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
160+
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
161161
mask_shapes = st.one_of(
162162
st.sampled_from([x.shape, ()]),
163163
st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
164164
lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l))
165165
),
166166
hh.shapes(),
167167
)
168-
key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
168+
key = data.draw(hh.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
169169

170170
if key.ndim > x.ndim or not all(
171171
ks in (xs, 0) for xs, ks in zip(x.shape, key.shape)
@@ -201,10 +201,10 @@ def test_getitem_masking(shape, data):
201201

202202
@given(hh.shapes(), st.data())
203203
def test_setitem_masking(shape, data):
204-
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
205-
key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key")
204+
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
205+
key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key")
206206
value = data.draw(
207-
xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value"
207+
hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value"
208208
)
209209

210210
res = xp.asarray(x, copy=True)
@@ -263,7 +263,7 @@ def test_scalar_casting(method_name, dtype_name, stype, data):
263263
dtype = getattr(_xp, dtype_name)
264264
except AttributeError as e:
265265
pytest.skip(str(e))
266-
x = data.draw(xps.arrays(dtype, shape=()), label="x")
266+
x = data.draw(hh.arrays(dtype, shape=()), label="x")
267267
method = getattr(x, method_name)
268268
out = method()
269269
assert isinstance(

array_api_tests/test_creation_functions.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def test_arange(dtype, data):
8686
start = data.draw(reals(), label="start")
8787
stop = data.draw(reals() | st.none(), label="stop")
8888
else:
89-
start = data.draw(xps.from_dtype(dtype), label="start")
90-
stop = data.draw(xps.from_dtype(dtype), label="stop")
89+
start = data.draw(hh.from_dtype(dtype), label="start")
90+
stop = data.draw(hh.from_dtype(dtype), label="stop")
9191
if stop is None:
9292
_start = 0
9393
_stop = start
@@ -107,9 +107,9 @@ def test_arange(dtype, data):
107107
step_strats = []
108108
if dtype in dh.int_dtypes:
109109
step_min = min(math.floor(-tol), -1)
110-
step_strats.append(xps.from_dtype(dtype, max_value=step_min))
110+
step_strats.append(hh.from_dtype(dtype, max_value=step_min))
111111
step_max = max(math.ceil(tol), 1)
112-
step_strats.append(xps.from_dtype(dtype, min_value=step_max))
112+
step_strats.append(hh.from_dtype(dtype, min_value=step_max))
113113
step = data.draw(st.one_of(step_strats), label="step")
114114
assert step != 0, "step must not equal 0" # sanity check
115115

@@ -215,11 +215,11 @@ def test_asarray_scalars(shape, data):
215215
else:
216216
_dtype = dtype
217217
if dh.is_float_dtype(_dtype):
218-
elements_strat = xps.from_dtype(_dtype) | xps.from_dtype(xp.int32)
218+
elements_strat = hh.from_dtype(_dtype) | hh.from_dtype(xp.int32)
219219
elif dh.is_int_dtype(_dtype):
220-
elements_strat = xps.from_dtype(_dtype) | st.booleans()
220+
elements_strat = hh.from_dtype(_dtype) | st.booleans()
221221
else:
222-
elements_strat = xps.from_dtype(_dtype)
222+
elements_strat = hh.from_dtype(_dtype)
223223
size = math.prod(shape)
224224
obj_strat = st.lists(elements_strat, min_size=size, max_size=size)
225225
scalar_type = dh.get_scalar_type(_dtype)
@@ -267,7 +267,7 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
267267
data=st.data(),
268268
)
269269
def test_asarray_arrays(shape, dtypes, data):
270-
x = data.draw(xps.arrays(dtype=dtypes.input_dtype, shape=shape), label="x")
270+
x = data.draw(hh.arrays(dtype=dtypes.input_dtype, shape=shape), label="x")
271271
dtypes_strat = st.just(dtypes.input_dtype)
272272
if dtypes.input_dtype == dtypes.result_dtype:
273273
dtypes_strat |= st.none()
@@ -290,7 +290,7 @@ def test_asarray_arrays(shape, dtypes, data):
290290
stype = dh.get_scalar_type(x.dtype)
291291
idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx")
292292
old_value = stype(x[idx])
293-
scalar_strat = xps.from_dtype(dtypes.input_dtype).filter(
293+
scalar_strat = hh.from_dtype(dtypes.input_dtype).filter(
294294
lambda n: not scalar_eq(n, old_value)
295295
)
296296
value = data.draw(
@@ -326,7 +326,7 @@ def test_empty(shape, kw):
326326

327327

328328
@given(
329-
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
329+
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
330330
kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()),
331331
)
332332
def test_empty_like(x, kw):
@@ -382,7 +382,7 @@ def full_fill_values(draw) -> Union[bool, int, float, complex]:
382382
st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw")
383383
)
384384
dtype = kw.get("dtype", None) or draw(default_safe_dtypes)
385-
return draw(xps.from_dtype(dtype))
385+
return draw(hh.from_dtype(dtype))
386386

387387

388388
@given(
@@ -430,8 +430,8 @@ def test_full(shape, fill_value, kw):
430430
@given(kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), data=st.data())
431431
def test_full_like(kw, data):
432432
dtype = kw.get("dtype", None) or data.draw(xps.scalar_dtypes(), label="dtype")
433-
x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes()), label="x")
434-
fill_value = data.draw(xps.from_dtype(dtype), label="fill_value")
433+
x = data.draw(hh.arrays(dtype=dtype, shape=hh.shapes()), label="x")
434+
fill_value = data.draw(hh.from_dtype(dtype), label="fill_value")
435435
out = xp.full_like(x, fill_value, **kw)
436436
dtype = kw.get("dtype", None) or x.dtype
437437
if kw.get("dtype", None) is None:
@@ -454,8 +454,8 @@ def test_full_like(kw, data):
454454
def test_linspace(num, dtype, endpoint, data):
455455
_dtype = dh.default_float if dtype is None else dtype
456456

457-
start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start")
458-
stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop")
457+
start = data.draw(hh.from_dtype(_dtype, **finite_kw), label="start")
458+
stop = data.draw(hh.from_dtype(_dtype, **finite_kw), label="stop")
459459
# avoid overflow errors
460460
assume(not xp.isnan(xp.asarray(stop - start, dtype=_dtype)))
461461
assume(not xp.isnan(xp.asarray(start - stop, dtype=_dtype)))
@@ -509,7 +509,7 @@ def test_meshgrid(dtype, data):
509509
)
510510
arrays = []
511511
for i, shape in enumerate(shapes, 1):
512-
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
512+
x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}")
513513
arrays.append(x)
514514
# sanity check
515515
assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
@@ -541,7 +541,7 @@ def test_ones(shape, kw):
541541

542542

543543
@given(
544-
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
544+
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
545545
kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()),
546546
)
547547
def test_ones_like(x, kw):
@@ -579,7 +579,7 @@ def test_zeros(shape, kw):
579579

580580

581581
@given(
582-
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
582+
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
583583
kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()),
584584
)
585585
def test_zeros_like(x, kw):

array_api_tests/test_data_type_functions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def float32(n: Union[int, float]) -> float:
3434
)
3535
def test_astype(x_dtype, dtype, kw, data):
3636
if xp.bool in (x_dtype, dtype):
37-
elements_strat = xps.from_dtype(x_dtype)
37+
elements_strat = hh.from_dtype(x_dtype)
3838
else:
3939
m1, M1 = dh.dtype_ranges[x_dtype]
4040
m2, M2 = dh.dtype_ranges[dtype]
@@ -46,15 +46,15 @@ def test_astype(x_dtype, dtype, kw, data):
4646
cast = float
4747
min_value = cast(max(m1, m2))
4848
max_value = cast(min(M1, M2))
49-
elements_strat = xps.from_dtype(
49+
elements_strat = hh.from_dtype(
5050
x_dtype,
5151
min_value=min_value,
5252
max_value=max_value,
5353
allow_nan=False,
5454
allow_infinity=False,
5555
)
5656
x = data.draw(
57-
xps.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
57+
hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
5858
)
5959

6060
out = xp.astype(x, dtype, **kw)
@@ -71,7 +71,7 @@ def test_astype(x_dtype, dtype, kw, data):
7171
def test_broadcast_arrays(shapes, data):
7272
arrays = []
7373
for c, shape in enumerate(shapes, 1):
74-
x = data.draw(xps.arrays(dtype=xps.scalar_dtypes(), shape=shape), label=f"x{c}")
74+
x = data.draw(hh.arrays(dtype=xps.scalar_dtypes(), shape=shape), label=f"x{c}")
7575
arrays.append(x)
7676

7777
out = xp.broadcast_arrays(*arrays)
@@ -94,7 +94,7 @@ def test_broadcast_arrays(shapes, data):
9494
# TODO: test values
9595

9696

97-
@given(x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data())
97+
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data())
9898
def test_broadcast_to(x, data):
9999
shape = data.draw(
100100
hh.mutually_broadcastable_shapes(1, base_shape=x.shape)
@@ -113,7 +113,7 @@ def test_broadcast_to(x, data):
113113
@given(_from=non_complex_dtypes(), to=non_complex_dtypes(), data=st.data())
114114
def test_can_cast(_from, to, data):
115115
from_ = data.draw(
116-
st.just(_from) | xps.arrays(dtype=_from, shape=hh.shapes()), label="from_"
116+
st.just(_from) | hh.arrays(dtype=_from, shape=hh.shapes()), label="from_"
117117
)
118118

119119
out = xp.can_cast(from_, to)

0 commit comments

Comments
 (0)