Skip to content

Commit 974367f

Browse files
authored
Merge pull request #226 from honno/jax-niceties
Don't test values that are on/near the boundary of an array's dtype
2 parents 7c89cf1 + 4d22300 commit 974367f

16 files changed

+226
-204
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 76 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import re
22
import itertools
33
from contextlib import contextmanager
4-
from functools import reduce
5-
from math import sqrt
4+
from functools import reduce, wraps
5+
import math
66
from operator import mul
7-
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
7+
import struct
8+
from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
89

910
from hypothesis import assume, reject
1011
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
1112
integers, just, lists, none, one_of,
12-
sampled_from, shared)
13+
sampled_from, shared, builds)
1314

1415
from . import _array_module as xp, api_version
1516
from . import dtype_helpers as dh
@@ -20,26 +21,60 @@
2021
from ._array_module import broadcast_to, eye, float32, float64, full
2122
from .stubs import category_to_funcs
2223
from .pytest_helpers import nargs
23-
from .typing import Array, DataType, Shape
24-
25-
# Set this to True to not fail tests just because a dtype isn't implemented.
26-
# If no compatible dtype is implemented for a given test, the test will fail
27-
# with a hypothesis health check error. Note that this functionality will not
28-
# work for floating point dtypes as those are assumed to be defined in other
29-
# places in the tests.
30-
FILTER_UNDEFINED_DTYPES = True
31-
# TODO: currently we assume this to be true - we probably can remove this completely
32-
assert FILTER_UNDEFINED_DTYPES
33-
34-
integer_dtypes = xps.integer_dtypes() | xps.unsigned_integer_dtypes()
35-
floating_dtypes = xps.floating_dtypes()
36-
numeric_dtypes = xps.numeric_dtypes()
37-
integer_or_boolean_dtypes = xps.boolean_dtypes() | integer_dtypes
38-
boolean_dtypes = xps.boolean_dtypes()
39-
dtypes = xps.scalar_dtypes()
40-
41-
shared_dtypes = shared(dtypes, key="dtype")
42-
shared_floating_dtypes = shared(floating_dtypes, key="dtype")
24+
from .typing import Array, DataType, Scalar, Shape
25+
26+
27+
def _float32ify(n: Union[int, float]) -> float:
28+
n = float(n)
29+
return struct.unpack("!f", struct.pack("!f", n))[0]
30+
31+
32+
@wraps(xps.from_dtype)
33+
def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
34+
"""xps.from_dtype() without the crazy large numbers."""
35+
if dtype == xp.bool:
36+
return xps.from_dtype(dtype, **kwargs)
37+
38+
if dtype in dh.complex_dtypes:
39+
component_dtype = dh.dtype_components[dtype]
40+
else:
41+
component_dtype = dtype
42+
43+
min_, max_ = dh.dtype_ranges[component_dtype]
44+
45+
if "min_value" not in kwargs.keys() and min_ != 0:
46+
assert min_ < 0 # sanity check
47+
min_value = -1 * math.floor(math.sqrt(abs(min_)))
48+
if component_dtype == xp.float32:
49+
min_value = _float32ify(min_value)
50+
kwargs["min_value"] = min_value
51+
if "max_value" not in kwargs.keys():
52+
assert max_ > 0 # sanity check
53+
max_value = math.floor(math.sqrt(max_))
54+
if component_dtype == xp.float32:
55+
max_value = _float32ify(max_value)
56+
kwargs["max_value"] = max_value
57+
58+
if dtype in dh.complex_dtypes:
59+
component_strat = xps.from_dtype(dh.dtype_components[dtype], **kwargs)
60+
return builds(complex, component_strat, component_strat)
61+
else:
62+
return xps.from_dtype(dtype, **kwargs)
63+
64+
65+
@wraps(xps.arrays)
66+
def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
67+
"""xps.arrays() without the crazy large numbers."""
68+
if isinstance(dtype, SearchStrategy):
69+
return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
70+
71+
if elements is None:
72+
elements = from_dtype(dtype)
73+
elif isinstance(elements, Mapping):
74+
elements = from_dtype(dtype, **elements)
75+
76+
return xps.arrays(dtype, *args, elements=elements, **kwargs)
77+
4378

4479
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
4580
_sorted_dtypes = [d for category in _dtype_categories for d in category]
@@ -62,21 +97,19 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
6297
return key
6398

6499
_promotable_dtypes = list(dh.promotion_table.keys())
65-
if FILTER_UNDEFINED_DTYPES:
66-
_promotable_dtypes = [
67-
(d1, d2) for d1, d2 in _promotable_dtypes
68-
if not isinstance(d1, _UndefinedStub) or not isinstance(d2, _UndefinedStub)
69-
]
100+
_promotable_dtypes = [
101+
(d1, d2) for d1, d2 in _promotable_dtypes
102+
if not isinstance(d1, _UndefinedStub) or not isinstance(d2, _UndefinedStub)
103+
]
70104
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(_promotable_dtypes, key=_dtypes_sorter)
71105

72106
def mutually_promotable_dtypes(
73107
max_size: Optional[int] = 2,
74108
*,
75109
dtypes: Sequence[DataType] = dh.all_dtypes,
76110
) -> SearchStrategy[Tuple[DataType, ...]]:
77-
if FILTER_UNDEFINED_DTYPES:
78-
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
79-
assert len(dtypes) > 0, "all dtypes undefined" # sanity check
111+
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
112+
assert len(dtypes) > 0, "all dtypes undefined" # sanity check
80113
if max_size == 2:
81114
return sampled_from(
82115
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
@@ -166,7 +199,7 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
166199
# Limit the total size of an array shape
167200
MAX_ARRAY_SIZE = 10000
168201
# Size to use for 2-dim arrays
169-
SQRT_MAX_ARRAY_SIZE = int(sqrt(MAX_ARRAY_SIZE))
202+
SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
170203

171204
# np.prod and others have overflow and math.prod is Python 3.8+ only
172205
def prod(seq):
@@ -202,7 +235,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
202235

203236
@composite
204237
def finite_matrices(draw, shape=matrix_shapes()):
205-
return draw(xps.arrays(dtype=xps.floating_dtypes(),
238+
return draw(arrays(dtype=xps.floating_dtypes(),
206239
shape=shape,
207240
elements=dict(allow_nan=False,
208241
allow_infinity=False)))
@@ -211,7 +244,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
211244
# Should we set a max_value here?
212245
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
213246
rtols = one_of(floats(**_rtol_float_kw),
214-
xps.arrays(dtype=xps.floating_dtypes(),
247+
arrays(dtype=xps.floating_dtypes(),
215248
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
216249
elements=_rtol_float_kw))
217250

@@ -254,7 +287,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True):
254287
if not isinstance(finite, bool):
255288
finite = draw(finite)
256289
elements = {'allow_nan': False, 'allow_infinity': False} if finite else None
257-
a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements))
290+
a = draw(arrays(dtype=dtype, shape=shape, elements=elements))
258291
upper = xp.triu(a)
259292
lower = xp.triu(a, k=1).mT
260293
return upper + lower
@@ -277,7 +310,7 @@ def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes(
277310
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),)
278311
stack_shape = draw(stack_shapes)
279312
shape = stack_shape + (n, n)
280-
d = draw(xps.arrays(dtypes, shape=n*prod(stack_shape),
313+
d = draw(arrays(dtypes, shape=n*prod(stack_shape),
281314
elements=dict(allow_nan=False, allow_infinity=False)))
282315
# Functions that require invertible matrices may do anything when it is
283316
# singular, including raising an exception, so we make sure the diagonals
@@ -303,7 +336,7 @@ def two_broadcastable_shapes(draw):
303336
sizes = integers(0, MAX_ARRAY_SIZE)
304337
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
305338

306-
numeric_arrays = xps.arrays(
339+
numeric_arrays = arrays(
307340
dtype=shared(xps.floating_dtypes(), key='dtypes'),
308341
shape=shared(xps.array_shapes(), key='shapes'),
309342
)
@@ -348,7 +381,7 @@ def python_integer_indices(draw, sizes):
348381
def integer_indices(draw, sizes):
349382
# Return either a Python integer or a 0-D array with some integer dtype
350383
idx = draw(python_integer_indices(sizes))
351-
dtype = draw(integer_dtypes)
384+
dtype = draw(xps.integer_dtypes() | xps.unsigned_integer_dtypes())
352385
m, M = dh.dtype_ranges[dtype]
353386
if m <= idx <= M:
354387
return draw(one_of(just(idx),
@@ -424,16 +457,15 @@ def two_mutual_arrays(
424457
) -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]:
425458
if not isinstance(dtypes, Sequence):
426459
raise TypeError(f"{dtypes=} not a sequence")
427-
if FILTER_UNDEFINED_DTYPES:
428-
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
429-
assert len(dtypes) > 0 # sanity check
460+
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
461+
assert len(dtypes) > 0 # sanity check
430462
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
431463
mutual_shapes = shared(two_shapes)
432-
arrays1 = xps.arrays(
464+
arrays1 = arrays(
433465
dtype=mutual_dtypes.map(lambda pair: pair[0]),
434466
shape=mutual_shapes.map(lambda pair: pair[0]),
435467
)
436-
arrays2 = xps.arrays(
468+
arrays2 = arrays(
437469
dtype=mutual_dtypes.map(lambda pair: pair[1]),
438470
shape=mutual_shapes.map(lambda pair: pair[1]),
439471
)

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,20 @@ def run(n, d, data):
128128
assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results)
129129

130130

131-
132131
@given(finite=st.booleans(), dtype=xps.floating_dtypes(), data=st.data())
133132
def test_symmetric_matrices(finite, dtype, data):
134-
m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite))
133+
m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite), label="m")
135134
assert m.dtype == dtype
136135
# TODO: This part of this test should be part of the .mT test
137136
ah.assert_exactly_equal(m, m.mT)
138137

139138
if finite:
140139
ah.assert_finite(m)
141140

142-
@given(m=hh.positive_definite_matrices(hh.shared_floating_dtypes),
143-
dtype=hh.shared_floating_dtypes)
144-
def test_positive_definite_matrices(m, dtype):
141+
142+
@given(dtype=xps.floating_dtypes(), data=st.data())
143+
def test_positive_definite_matrices(dtype, data):
144+
m = data.draw(hh.positive_definite_matrices(st.just(dtype)), label="m")
145145
assert m.dtype == dtype
146146
# TODO: Test that it actually is positive definite
147147

array_api_tests/test_array_object.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def scalar_objects(
2424
) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
2525
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
2626
size = math.prod(shape)
27-
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
27+
return st.lists(hh.from_dtype(dtype), min_size=size, max_size=size).map(
2828
lambda l: sh.reshape(l, shape)
2929
)
3030

@@ -123,10 +123,10 @@ def test_setitem(shape, dtypes, data):
123123
key = data.draw(xps.indices(shape=shape), label="key")
124124
_key = normalise_key(key, shape)
125125
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
126-
value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape)
126+
value_strat = hh.arrays(dtype=dtypes.result_dtype, shape=out_shape)
127127
if out_shape == ():
128128
# We can pass scalars if we're only indexing one element
129-
value_strat |= xps.from_dtype(dtypes.result_dtype)
129+
value_strat |= hh.from_dtype(dtypes.result_dtype)
130130
value = data.draw(value_strat, label="value")
131131

132132
res = xp.asarray(x, copy=True)
@@ -157,15 +157,15 @@ def test_setitem(shape, dtypes, data):
157157
@pytest.mark.data_dependent_shapes
158158
@given(hh.shapes(), st.data())
159159
def test_getitem_masking(shape, data):
160-
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
160+
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
161161
mask_shapes = st.one_of(
162162
st.sampled_from([x.shape, ()]),
163163
st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
164164
lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l))
165165
),
166166
hh.shapes(),
167167
)
168-
key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
168+
key = data.draw(hh.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
169169

170170
if key.ndim > x.ndim or not all(
171171
ks in (xs, 0) for xs, ks in zip(x.shape, key.shape)
@@ -201,10 +201,10 @@ def test_getitem_masking(shape, data):
201201

202202
@given(hh.shapes(), st.data())
203203
def test_setitem_masking(shape, data):
204-
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
205-
key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key")
204+
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
205+
key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key")
206206
value = data.draw(
207-
xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value"
207+
hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value"
208208
)
209209

210210
res = xp.asarray(x, copy=True)
@@ -263,7 +263,7 @@ def test_scalar_casting(method_name, dtype_name, stype, data):
263263
dtype = getattr(_xp, dtype_name)
264264
except AttributeError as e:
265265
pytest.skip(str(e))
266-
x = data.draw(xps.arrays(dtype, shape=()), label="x")
266+
x = data.draw(hh.arrays(dtype, shape=()), label="x")
267267
method = getattr(x, method_name)
268268
out = method()
269269
assert isinstance(

0 commit comments

Comments
 (0)