Skip to content

Commit 4aa49d4

Browse files
committed
Assert inferred dtype correctly in test_full
1 parent 0fb851f commit 4aa49d4

File tree

1 file changed

+49
-30
lines changed

1 file changed

+49
-30
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ._array_module import (asarray, arange, ceil, empty, empty_like, eye, full,
22
full_like, equal, all, linspace, ones, ones_like,
3-
zeros, zeros_like, isnan, float32)
3+
zeros, zeros_like, isnan)
4+
from . import _array_module as xp
45
from .array_helpers import (is_integer_dtype, dtype_ranges,
56
assert_exactly_equal, isintegral, is_float_dtype)
67
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
@@ -66,18 +67,17 @@ def test_arange(start, stop, step, dtype):
6667
or step < 0 and stop <= start)):
6768
assert a.size == ceil(asarray((stop-start)/step)), "arange() produced an array of the incorrect size"
6869

69-
@given(one_of(shapes, sizes), one_of(none(), dtypes))
70-
def test_empty(shape, dtype):
71-
if dtype is None:
72-
a = empty(shape)
73-
assert is_float_dtype(a.dtype), "empty() should returned an array with the default floating point dtype"
70+
@given(shapes, kwargs(dtype=none() | shared_dtypes))
71+
def test_empty(shape, kw):
72+
out = empty(shape, **kw)
73+
dtype = kw.get("dtype", None) or xp.float64
74+
if kw.get("dtype", None) is None:
75+
assert is_float_dtype(out.dtype), f"empty() returned an array with dtype {out.dtype}, but should be the default float dtype"
7476
else:
75-
a = empty(shape, dtype=dtype)
76-
assert a.dtype == dtype
77-
77+
assert out.dtype == dtype, f"{dtype=!s}, but empty() returned an array with dtype {out.dtype}"
7878
if isinstance(shape, int):
7979
shape = (shape,)
80-
assert a.shape == shape, "empty() produced an array with an incorrect shape"
80+
assert out.shape == shape, f"{shape=}, but empty() returned an array with shape {out.shape}"
8181

8282

8383
@given(
@@ -124,36 +124,55 @@ def test_eye(n_rows, n_cols, k, dtype):
124124
else:
125125
assert a[i, j] == 0, "eye() did not produce a 0 off the diagonal"
126126

127-
@given(shapes, scalars(shared_dtypes), one_of(none(), shared_dtypes))
128-
def test_full(shape, fill_value, dtype):
129-
kwargs = {} if dtype is None else {'dtype': dtype}
130127

131-
a = full(shape, fill_value, **kwargs)
128+
@composite
129+
def full_fill_values(draw):
130+
kw = draw(shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_kw"))
131+
dtype = kw.get("dtype", None) or draw(xps.scalar_dtypes())
132+
return draw(xps.from_dtype(dtype))
132133

133-
if dtype is None:
134-
# TODO: Should it actually match the fill_value?
135-
# assert a.dtype in _floating_dtypes, "eye() should returned an array with the default floating point dtype"
136-
pass
137-
else:
138-
assert a.dtype == dtype
139134

140-
assert a.shape == shape, "full() produced an array with incorrect shape"
141-
if is_float_dtype(a.dtype) and isnan(asarray(fill_value)):
142-
assert all(isnan(a)), "full() array did not equal the fill value"
135+
@given(
136+
shape=shapes,
137+
fill_value=full_fill_values(),
138+
kw=shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_kw"),
139+
)
140+
def test_full(shape, fill_value, kw):
141+
out = full(shape, fill_value, **kw)
142+
if kw.get("dtype", None):
143+
dtype = kw["dtype"]
144+
elif isinstance(fill_value, bool):
145+
dtype = xp.bool
146+
elif isinstance(fill_value, int):
147+
dtype = xp.int64
148+
else:
149+
dtype = xp.float64
150+
if kw.get("dtype", None) is None:
151+
if dtype == xp.float64:
152+
assert is_float_dtype(out.dtype), f"full() returned an array with dtype {out.dtype}, but should be the default float dtype"
153+
elif dtype == xp.int64:
154+
assert out.dtype == xp.int32 or out.dtype == xp.int64, f"full() returned an array with dtype {out.dtype}, but should be the default integer dtype"
155+
else:
156+
assert out.dtype == xp.bool, f"full() returned an array with dtype {out.dtype}, but should be the bool dtype"
157+
else:
158+
assert out.dtype == dtype
159+
assert out.shape == shape, f"{shape=}, but full() returned an array with shape {out.shape}"
160+
if is_float_dtype(out.dtype) and isnan(asarray(fill_value)):
161+
assert all(isnan(out)), "full() array did not equal the fill value"
143162
else:
144-
assert all(equal(a, asarray(fill_value, **kwargs))), "full() array did not equal the fill value"
163+
assert all(equal(out, asarray(fill_value, dtype=dtype))), "full() array did not equal the fill value"
145164

146165

147166
@composite
148-
def fill_values(draw):
167+
def full_like_fill_values(draw):
149168
kw = draw(shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw"))
150169
dtype = kw.get("dtype", None) or draw(shared_dtypes)
151170
return draw(xps.from_dtype(dtype))
152171

153172

154173
@given(
155174
x=xps.arrays(dtype=shared_dtypes, shape=shapes),
156-
fill_value=fill_values(),
175+
fill_value=full_like_fill_values(),
157176
kw=shared(kwargs(dtype=none() | xps.scalar_dtypes()), key="full_like_kw"),
158177
)
159178
def test_full_like(x, fill_value, kw):
@@ -226,12 +245,12 @@ def make_one(dtype):
226245
@given(shapes, kwargs(dtype=none() | xps.scalar_dtypes()))
227246
def test_ones(shape, kw):
228247
out = ones(shape, **kw)
229-
dtype = kw.get("dtype", None) or float32
248+
dtype = kw.get("dtype", None) or xp.float64
230249
if kw.get("dtype", None) is None:
231-
assert is_float_dtype(out.dtype), "ones() returned an array with dtype {x.dtype}, but should be the default float dtype"
250+
assert is_float_dtype(out.dtype), f"ones() returned an array with dtype {out.dtype}, but should be the default float dtype"
232251
else:
233252
assert out.dtype == dtype, f"{dtype=!s}, but ones() returned an array with dtype {out.dtype}"
234-
assert out.shape == shape, "ones() produced an array with incorrect shape"
253+
assert out.shape == shape, f"{shape=}, but empty() returned an array with shape {out.shape}"
235254
assert all(equal(out, full((), make_one(dtype), dtype=dtype))), "ones() array did not equal 1"
236255

237256

@@ -262,7 +281,7 @@ def make_zero(dtype):
262281
@given(shapes, kwargs(dtype=none() | xps.scalar_dtypes()))
263282
def test_zeros(shape, kw):
264283
out = zeros(shape, **kw)
265-
dtype = kw.get("dtype", None) or float32
284+
dtype = kw.get("dtype", None) or xp.float64
266285
if kw.get("dtype", None) is None:
267286
assert is_float_dtype(out.dtype), "zeros() returned an array with dtype {out.dtype}, but should be the default float dtype"
268287
else:

0 commit comments

Comments
 (0)