Skip to content

Commit fc1e7f6

Browse files
committed
Rudimentary empty/ones/zeros-like tests
1 parent d42c9d2 commit fc1e7f6

File tree

1 file changed

+88
-12
lines changed

1 file changed

+88
-12
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from ._array_module import (asarray, arange, ceil, empty, eye, full, full_like,
2-
equal, all, linspace, ones, zeros, isnan)
1+
from ._array_module import (asarray, arange, ceil, empty, empty_like, eye, full,
2+
full_like, equal, all, linspace, ones, ones_like,
3+
zeros, zeros_like, isnan)
34
from .array_helpers import (is_integer_dtype, dtype_ranges,
45
assert_exactly_equal, isintegral, is_float_dtype)
56
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
@@ -77,9 +78,29 @@ def test_empty(shape, dtype):
7778
shape = (shape,)
7879
assert a.shape == shape, "empty() produced an array with an incorrect shape"
7980

80-
# TODO: implement empty_like (requires hypothesis arrays support)
81-
def test_empty_like():
82-
pass
81+
82+
@given(
83+
a=xps.arrays(
84+
dtype=shared(xps.scalar_dtypes(), key='dtypes'),
85+
shape=xps.array_shapes(),
86+
),
87+
kwargs=one_of(
88+
just({}),
89+
shared(xps.scalar_dtypes(), key='dtypes').map(lambda d: {'dtype': d}),
90+
),
91+
)
92+
def test_empty_like(a, kwargs):
93+
a_like = empty_like(a, **kwargs)
94+
95+
if kwargs is None:
96+
# TODO: Should it actually match a.dtype?
97+
# assert is_float_dtype(a_like.dtype), "empty_like() should produce an array with the default floating point dtype"
98+
pass
99+
else:
100+
assert a_like.dtype == a.dtype, "empty_like() produced an array with an incorrect dtype"
101+
102+
assert a_like.shape == a.shape, "empty_like() produced an array with an incorrect shape"
103+
83104

84105
# TODO: Use this method for all optional arguments
85106
optional_marker = object()
@@ -145,7 +166,8 @@ def test_full_like(a, fill_value, kwargs):
145166
a_like = full_like(a, fill_value, **kwargs)
146167

147168
if kwargs is None:
148-
pass # TODO: Should it actually match the fill_value?
169+
# TODO: Should it actually match a.dtype?
170+
pass
149171
else:
150172
assert a_like.dtype == a.dtype
151173

@@ -222,9 +244,36 @@ def test_ones(shape, dtype):
222244
assert a.shape == shape, "ones() produced an array with incorrect shape"
223245
assert all(equal(a, full((), ONE, **kwargs))), "ones() array did not equal 1"
224246

225-
# TODO: implement ones_like (requires hypothesis arrays support)
226-
def test_ones_like():
227-
pass
247+
248+
@given(
249+
a=xps.arrays(
250+
dtype=shared(xps.scalar_dtypes(), key='dtypes'),
251+
shape=xps.array_shapes(),
252+
),
253+
kwargs=one_of(
254+
just({}),
255+
shared(xps.scalar_dtypes(), key='dtypes').map(lambda d: {'dtype': d}),
256+
),
257+
)
258+
def test_ones_like(a, kwargs):
259+
if kwargs is None or is_float_dtype(a.dtype):
260+
ONE = 1.0
261+
elif is_integer_dtype(a.dtype):
262+
ONE = 1
263+
else:
264+
ONE = True
265+
266+
a_like = ones_like(a, **kwargs)
267+
268+
if kwargs is None:
269+
# TODO: Should it actually match a.dtype?
270+
pass
271+
else:
272+
assert a_like.dtype == a.dtype, "ones_like() produced an array with an incorrect dtype"
273+
274+
assert a_like.shape == a.shape, "ones_like() produced an array with an incorrect shape"
275+
assert all(equal(a_like, full((), ONE, dtype=a_like.dtype))), "ones_like() array did not equal 1"
276+
228277

229278
@given(shapes, one_of(none(), dtypes))
230279
def test_zeros(shape, dtype):
@@ -248,6 +297,33 @@ def test_zeros(shape, dtype):
248297
assert a.shape == shape, "zeros() produced an array with incorrect shape"
249298
assert all(equal(a, full((), ZERO, **kwargs))), "zeros() array did not equal 0"
250299

251-
# TODO: implement zeros_like (requires hypothesis arrays support)
252-
def test_zeros_like():
253-
pass
300+
301+
@given(
302+
a=xps.arrays(
303+
dtype=shared(xps.scalar_dtypes(), key='dtypes'),
304+
shape=xps.array_shapes(),
305+
),
306+
kwargs=one_of(
307+
just({}),
308+
shared(xps.scalar_dtypes(), key='dtypes').map(lambda d: {'dtype': d}),
309+
),
310+
)
311+
def test_zeros_like(a, kwargs):
312+
if kwargs is None or is_float_dtype(a.dtype):
313+
ZERO = 0.0
314+
elif is_integer_dtype(a.dtype):
315+
ZERO = 0
316+
else:
317+
ZERO = False
318+
319+
a_like = zeros_like(a, **kwargs)
320+
321+
if kwargs is None:
322+
# TODO: Should it actually match a.dtype?
323+
pass
324+
else:
325+
assert a_like.dtype == a.dtype, "zeros_like() produced an array with an incorrect dtype"
326+
327+
assert a_like.shape == a.shape, "zeros_like() produced an array with an incorrect shape"
328+
assert all(equal(a_like, full((), ZERO, dtype=a_like.dtype))), "zeros_like() array did not equal 0"
329+

0 commit comments

Comments
 (0)