Skip to content

Commit 1cf1747

Browse files
committed
Add tests for ones() and zeros()
1 parent 2763d3e commit 1cf1747

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._array_module import (arange, ceil, empty, _floating_dtypes, eye, full,
2-
equal, all, linspace)
2+
equal, all, linspace, ones, zeros)
33
from .array_helpers import (is_integer_dtype, dtype_ranges,
44
assert_exactly_equal, isintegral)
55
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
@@ -169,3 +169,45 @@ def test_linspace(start, stop, num, dtype, endpoint):
169169
# n = num - 1 if endpoint in [None, True] else num
170170
# for i in range(1, num):
171171
# assert all(equal(a[i], full((), i*(stop - start)/n + start, dtype=dtype))), f"linspace() produced an array with an incorrect value at index {i}"
172+
173+
@given(shapes, one_of(none(), dtypes))
174+
def test_ones(shape, dtype):
175+
kwargs = {} if dtype is None else {'dtype': dtype}
176+
177+
a = ones(shape, **kwargs)
178+
179+
if dtype is None:
180+
# TODO: Should it actually match the fill_value?
181+
# assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
182+
pass
183+
else:
184+
assert a.dtype == dtype
185+
186+
assert a.shape == shape, "ones() produced an array with incorrect shape"
187+
assert all(equal(a, full((), 1, **kwargs))), "ones() array did not equal 1"
188+
189+
# TODO: implement ones_like (requires hypothesis arrays support)
190+
def test_ones_like():
191+
pass
192+
193+
194+
195+
@given(shapes, one_of(none(), dtypes))
196+
def test_zeros(shape, dtype):
197+
kwargs = {} if dtype is None else {'dtype': dtype}
198+
199+
a = zeros(shape, **kwargs)
200+
201+
if dtype is None:
202+
# TODO: Should it actually match the fill_value?
203+
# assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
204+
pass
205+
else:
206+
assert a.dtype == dtype
207+
208+
assert a.shape == shape, "zeros() produced an array with incorrect shape"
209+
assert all(equal(a, full((), 0, **kwargs))), "zeros() array did not equal 0"
210+
211+
# TODO: implement zeros_like (requires hypothesis arrays support)
212+
def test_zeros_like():
213+
pass

0 commit comments

Comments
 (0)