Skip to content

Commit ff45b23

Browse files
committed
Add a test for full()
1 parent 8eb0873 commit ff45b23

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

array_api_tests/_array_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __repr__(self):
5454
__getattr__ = _raise
5555

5656
_integer_dtypes = [
57-
'int8',
57+
'int8',
5858
'int16',
5959
'int32',
6060
'int64',

array_api_tests/hypothesis_helpers.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44

55
from hypothesis.strategies import (lists, integers, builds, sampled_from,
66
shared, tuples as hypotheses_tuples,
7-
floats, just, composite, one_of, none)
7+
floats, just, composite, one_of, none,
8+
booleans)
89
from hypothesis import assume
910

1011
from .pytest_helpers import nargs
12+
from .array_helpers import dtype_ranges
1113
from ._array_module import (_integer_dtypes, _floating_dtypes,
12-
_numeric_dtypes, _dtypes, ones, full)
14+
_numeric_dtypes, _dtypes, ones, full, float32,
15+
float64, bool as bool_dtype)
1316
from . import _array_module
1417

1518
from .function_stubs import elementwise_functions
@@ -73,6 +76,24 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
7376
# TODO: Generate general arrays here, rather than just scalars.
7477
numeric_arrays = builds(full, just((1,)), floats())
7578

79+
@composite
80+
def shared_scalars(draw):
81+
"""
82+
Strategy to generate a scalar that matches the dtype from shared_dtypes
83+
"""
84+
dtype = draw(shared_dtypes)
85+
if dtype in dtype_ranges:
86+
m, M = dtype_ranges[dtype]
87+
return draw(integers(m, M))
88+
elif dtype == bool_dtype:
89+
return draw(booleans())
90+
elif dtype == float64:
91+
return draw(floats())
92+
elif dtype == float32:
93+
return draw(floats(width=32))
94+
else:
95+
raise ValueError(f"Unrecognized dtype {dtype}")
96+
7697
@composite
7798
def integer_indices(draw, sizes):
7899
size = draw(sizes)

array_api_tests/test_creation_functions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from ._array_module import arange, ceil, empty, _floating_dtypes, eye
1+
from ._array_module import (arange, ceil, empty, _floating_dtypes, eye, full,
2+
equal, all)
23
from .array_helpers import is_integer_dtype, dtype_ranges
34
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
4-
shapes, sizes, sqrt_sizes)
5+
shapes, sizes, sqrt_sizes, shared_dtypes,
6+
shared_scalars)
57

68
from hypothesis import assume, given
79
from hypothesis.strategies import integers, floats, one_of, none
@@ -98,3 +100,19 @@ def test_eye(N, M, k, dtype):
98100
assert a[i, j] == 1, "eye() did not produce a 1 on the diagonal"
99101
else:
100102
assert a[i, j] == 0, "eye() did not produce a 0 off the diagonal"
103+
104+
@given(shapes, shared_scalars(), one_of(none(), shared_dtypes))
105+
def test_full(shape, fill_value, dtype):
106+
kwargs = {} if dtype is None else {'dtype': dtype}
107+
108+
a = full(shape, fill_value, **kwargs)
109+
110+
if dtype is None:
111+
# TODO: Should it actually match the fill_value?
112+
# assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
113+
pass
114+
else:
115+
assert a.dtype == dtype
116+
117+
assert a.shape == shape, "full() produced an array with incorrect shape"
118+
assert all(equal(a, fill_value)), "full() array did not equal the fill value"

0 commit comments

Comments
 (0)