Skip to content

Commit 3506c00

Browse files
committed
Make new tests more in-line with established style
1 parent 2acbbc2 commit 3506c00

File tree

1 file changed

+22
-29
lines changed

1 file changed

+22
-29
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,17 @@ def test_empty(shape, dtype):
8181

8282
@given(
8383
a=xps.arrays(
84-
dtype=shared(xps.scalar_dtypes(), key='dtypes'),
84+
dtype=shared_dtypes,
8585
shape=xps.array_shapes(),
8686
),
87-
kwargs=one_of(
88-
just({}),
89-
shared(xps.scalar_dtypes(), key='dtypes').map(lambda d: {'dtype': d}),
90-
),
87+
dtype=one_of(none(), shared_dtypes),
9188
)
92-
def test_empty_like(a, kwargs):
89+
def test_empty_like(a, dtype):
90+
kwargs = {} if dtype is None else {'dtype': dtype}
91+
9392
a_like = empty_like(a, **kwargs)
9493

95-
if kwargs is None:
94+
if dtype is None:
9695
# TODO: Should it actually match a.dtype?
9796
# assert is_float_dtype(a_like.dtype), "empty_like() should produce an array with the default floating point dtype"
9897
pass
@@ -153,26 +152,24 @@ def test_full(shape, fill_value, dtype):
153152

154153
@given(
155154
a=xps.arrays(
156-
dtype=shared(xps.scalar_dtypes(), key='dtypes'),
155+
dtype=shared_dtypes,
157156
shape=xps.array_shapes(),
158157
),
159-
fill_value=shared(xps.scalar_dtypes(), key='dtypes').flatmap(xps.from_dtype),
160-
kwargs=one_of(
161-
just({}),
162-
shared(xps.scalar_dtypes(), key='dtypes').map(lambda d: {'dtype': d}),
163-
),
158+
fill_value=shared_dtypes.flatmap(xps.from_dtype),
159+
dtype=one_of(none(), shared_dtypes),
164160
)
165-
def test_full_like(a, fill_value, kwargs):
161+
def test_full_like(a, fill_value, dtype):
162+
kwargs = {} if dtype is None else {'dtype': dtype}
163+
166164
a_like = full_like(a, fill_value, **kwargs)
167165

168-
if kwargs is None:
166+
if dtype is None:
169167
# TODO: Should it actually match a.dtype?
170168
pass
171169
else:
172170
assert a_like.dtype == a.dtype
173171

174172
assert a_like.shape == a.shape, "full_like() produced an array with incorrect shape"
175-
176173
if is_float_dtype(a_like.dtype) and isnan(asarray(fill_value)):
177174
assert all(isnan(a_like)), "full_like() array did not equal the fill value"
178175
else:
@@ -247,15 +244,13 @@ def test_ones(shape, dtype):
247244

248245
@given(
249246
a=xps.arrays(
250-
dtype=shared(xps.scalar_dtypes(), key='dtypes'),
247+
dtype=shared_dtypes,
251248
shape=xps.array_shapes(),
252249
),
253-
kwargs=one_of(
254-
just({}),
255-
shared(xps.scalar_dtypes(), key='dtypes').map(lambda d: {'dtype': d}),
256-
),
250+
dtype=one_of(none(), shared_dtypes),
257251
)
258-
def test_ones_like(a, kwargs):
252+
def test_ones_like(a, dtype):
253+
kwargs = {} if dtype is None else {'dtype': dtype}
259254
if kwargs is None or is_float_dtype(a.dtype):
260255
ONE = 1.0
261256
elif is_integer_dtype(a.dtype):
@@ -300,16 +295,14 @@ def test_zeros(shape, dtype):
300295

301296
@given(
302297
a=xps.arrays(
303-
dtype=shared(xps.scalar_dtypes(), key='dtypes'),
298+
dtype=shared_dtypes,
304299
shape=xps.array_shapes(),
305300
),
306-
kwargs=one_of(
307-
just({}),
308-
shared(xps.scalar_dtypes(), key='dtypes').map(lambda d: {'dtype': d}),
309-
),
301+
dtype=one_of(none(), shared_dtypes),
310302
)
311-
def test_zeros_like(a, kwargs):
312-
if kwargs is None or is_float_dtype(a.dtype):
303+
def test_zeros_like(a, dtype):
304+
kwargs = {} if dtype is None else {'dtype': dtype}
305+
if dtype is None or is_float_dtype(a.dtype):
313306
ZERO = 0.0
314307
elif is_integer_dtype(a.dtype):
315308
ZERO = 0

0 commit comments

Comments
 (0)