|
1 | 1 | from ._array_module import (arange, ceil, empty, _floating_dtypes, eye, full,
|
2 |
| -equal, all) |
3 |
| -from .array_helpers import is_integer_dtype, dtype_ranges |
| 2 | +equal, all, linspace) |
| 3 | +from .array_helpers import (is_integer_dtype, dtype_ranges, |
| 4 | + assert_exactly_equal, isintegral) |
4 | 5 | from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
|
5 | 6 | shapes, sizes, sqrt_sizes, shared_dtypes,
|
6 | 7 | shared_scalars)
|
7 | 8 |
|
8 | 9 | from hypothesis import assume, given
|
9 |
| -from hypothesis.strategies import integers, floats, one_of, none |
| 10 | +from hypothesis.strategies import integers, floats, one_of, none, booleans |
10 | 11 |
|
11 | 12 | int_range = integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)
|
12 | 13 | float_range = floats(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE, allow_nan=False)
|
@@ -86,7 +87,7 @@ def test_eye(N, M, k, dtype):
|
86 | 87 | if dtype is None:
|
87 | 88 | assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
|
88 | 89 | else:
|
89 |
| - assert a.dtype == dtype |
| 90 | + assert a.dtype == dtype, "eye() did not produce the correct dtype" |
90 | 91 |
|
91 | 92 | if M is None:
|
92 | 93 | M = N
|
@@ -116,3 +117,55 @@ def test_full(shape, fill_value, dtype):
|
116 | 117 |
|
117 | 118 | assert a.shape == shape, "full() produced an array with incorrect shape"
|
118 | 119 | assert all(equal(a, fill_value)), "full() array did not equal the fill value"
|
| 120 | + |
| 121 | +# TODO: implement full_like (requires hypothesis arrays support) |
| 122 | +def test_full_like(): |
| 123 | + pass |
| 124 | + |
| 125 | +@given(one_of(integers(), floats(allow_nan=False, allow_infinity=False)), |
| 126 | + one_of(integers(), floats(allow_nan=False, allow_infinity=False)), |
| 127 | + sizes, |
| 128 | + one_of(none(), dtypes), |
| 129 | + one_of(none(), booleans()),) |
| 130 | +def test_linspace(start, stop, num, dtype, endpoint): |
| 131 | + if dtype in dtype_ranges: |
| 132 | + m, M = dtype_ranges[dtype] |
| 133 | + if (isinstance(start, int) and not (m <= start <= M) |
| 134 | + or isinstance(stop, int) and not (m <= stop <= M)): |
| 135 | + assume(False) |
| 136 | + # Skip on int start or stop that cannot be exactly represented as a float, |
| 137 | + # since we do not have good approx_equal helpers yet. |
| 138 | + if (dtype is None or dtype in _floating_dtypes |
| 139 | + and ((isinstance(start, int) and not isintegral(start)) |
| 140 | + or (isinstance(stop, int) and not isintegral(stop)))): |
| 141 | + assume(False) |
| 142 | + |
| 143 | + kwargs = {k: v for k, v in {'dtype': dtype, 'endpoint': endpoint}.items() |
| 144 | + if v is not None} |
| 145 | + a = linspace(start, stop, num, **kwargs) |
| 146 | + |
| 147 | + if dtype is None: |
| 148 | + assert a.dtype in _floating_dtypes, "linspace() should produce an array with the default floating point dtype" |
| 149 | + else: |
| 150 | + assert a.dtype == dtype, "linspace() did not produce the correct dtype" |
| 151 | + |
| 152 | + assert a.shape == (num,), "linspace() did not produce an array with the correct shape" |
| 153 | + |
| 154 | + if endpoint in [None, True]: |
| 155 | + if num > 1: |
| 156 | + assert all(equal(a[-1], full((), stop, dtype=dtype))), "linspace() produced an array that does not the endpoint" |
| 157 | + else: |
| 158 | + # linspace(..., num, endpoint=False) is the same as the first num |
| 159 | + # elements of linspace(..., num+1, endpoint=True) |
| 160 | + b = linspace(start, stop, num + 1, **{**kwargs, 'endpoint': True}) |
| 161 | + assert_exactly_equal(b[:-1], a) |
| 162 | + |
| 163 | + if num > 0: |
| 164 | + # We need to cast start to dtype |
| 165 | + assert all(equal(a[0], full((), start, dtype=dtype))), "linspace() produced an array that does not start with the start" |
| 166 | + |
| 167 | + # TODO: This requires an assert_approx_equal function |
| 168 | + |
| 169 | + # n = num - 1 if endpoint in [None, True] else num |
| 170 | + # for i in range(1, num): |
| 171 | + # assert all(equal(a[i], full((), i*(stop - start)/n + start, dtype=dtype))), f"linspace() produced an array with an incorrect value at index {i}" |
0 commit comments