Skip to content

Commit 2763d3e

Browse files
committed
Add tests for linspace()
1 parent c0ee777 commit 2763d3e

File tree

1 file changed

+57
-4
lines changed

1 file changed

+57
-4
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from ._array_module import (arange, ceil, empty, _floating_dtypes, eye, full,
2-
equal, all)
3-
from .array_helpers import is_integer_dtype, dtype_ranges
2+
equal, all, linspace)
3+
from .array_helpers import (is_integer_dtype, dtype_ranges,
4+
assert_exactly_equal, isintegral)
45
from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
56
shapes, sizes, sqrt_sizes, shared_dtypes,
67
shared_scalars)
78

89
from hypothesis import assume, given
9-
from hypothesis.strategies import integers, floats, one_of, none
10+
from hypothesis.strategies import integers, floats, one_of, none, booleans
1011

1112
int_range = integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)
1213
float_range = floats(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE, allow_nan=False)
@@ -86,7 +87,7 @@ def test_eye(N, M, k, dtype):
8687
if dtype is None:
8788
assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
8889
else:
89-
assert a.dtype == dtype
90+
assert a.dtype == dtype, "eye() did not produce the correct dtype"
9091

9192
if M is None:
9293
M = N
@@ -116,3 +117,55 @@ def test_full(shape, fill_value, dtype):
116117

117118
assert a.shape == shape, "full() produced an array with incorrect shape"
118119
assert all(equal(a, fill_value)), "full() array did not equal the fill value"
120+
121+
# TODO: implement full_like (requires hypothesis arrays support)
122+
def test_full_like():
123+
pass
124+
125+
@given(one_of(integers(), floats(allow_nan=False, allow_infinity=False)),
126+
one_of(integers(), floats(allow_nan=False, allow_infinity=False)),
127+
sizes,
128+
one_of(none(), dtypes),
129+
one_of(none(), booleans()),)
130+
def test_linspace(start, stop, num, dtype, endpoint):
131+
if dtype in dtype_ranges:
132+
m, M = dtype_ranges[dtype]
133+
if (isinstance(start, int) and not (m <= start <= M)
134+
or isinstance(stop, int) and not (m <= stop <= M)):
135+
assume(False)
136+
# Skip on int start or stop that cannot be exactly represented as a float,
137+
# since we do not have good approx_equal helpers yet.
138+
if (dtype is None or dtype in _floating_dtypes
139+
and ((isinstance(start, int) and not isintegral(start))
140+
or (isinstance(stop, int) and not isintegral(stop)))):
141+
assume(False)
142+
143+
kwargs = {k: v for k, v in {'dtype': dtype, 'endpoint': endpoint}.items()
144+
if v is not None}
145+
a = linspace(start, stop, num, **kwargs)
146+
147+
if dtype is None:
148+
assert a.dtype in _floating_dtypes, "linspace() should produce an array with the default floating point dtype"
149+
else:
150+
assert a.dtype == dtype, "linspace() did not produce the correct dtype"
151+
152+
assert a.shape == (num,), "linspace() did not produce an array with the correct shape"
153+
154+
if endpoint in [None, True]:
155+
if num > 1:
156+
assert all(equal(a[-1], full((), stop, dtype=dtype))), "linspace() produced an array that does not the endpoint"
157+
else:
158+
# linspace(..., num, endpoint=False) is the same as the first num
159+
# elements of linspace(..., num+1, endpoint=True)
160+
b = linspace(start, stop, num + 1, **{**kwargs, 'endpoint': True})
161+
assert_exactly_equal(b[:-1], a)
162+
163+
if num > 0:
164+
# We need to cast start to dtype
165+
assert all(equal(a[0], full((), start, dtype=dtype))), "linspace() produced an array that does not start with the start"
166+
167+
# TODO: This requires an assert_approx_equal function
168+
169+
# n = num - 1 if endpoint in [None, True] else num
170+
# for i in range(1, num):
171+
# assert all(equal(a[i], full((), i*(stop - start)/n + start, dtype=dtype))), f"linspace() produced an array with an incorrect value at index {i}"

0 commit comments

Comments
 (0)