Skip to content

Commit 73c47d8

Browse files
committed
Remove redundant dtype strategies in hypothesis_helpers.py
1 parent e901687 commit 73c47d8

File tree

3 files changed

+13
-33
lines changed

3 files changed

+13
-33
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,6 @@
2222
from .pytest_helpers import nargs
2323
from .typing import Array, DataType, Shape
2424

25-
integer_dtypes = xps.integer_dtypes() | xps.unsigned_integer_dtypes()
26-
floating_dtypes = xps.floating_dtypes()
27-
numeric_dtypes = xps.numeric_dtypes()
28-
integer_or_boolean_dtypes = xps.boolean_dtypes() | integer_dtypes
29-
boolean_dtypes = xps.boolean_dtypes()
30-
dtypes = xps.scalar_dtypes()
31-
32-
shared_dtypes = shared(dtypes, key="dtype")
33-
shared_floating_dtypes = shared(floating_dtypes, key="dtype")
34-
3525
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
3626
_sorted_dtypes = [d for category in _dtype_categories for d in category]
3727

@@ -337,7 +327,7 @@ def python_integer_indices(draw, sizes):
337327
def integer_indices(draw, sizes):
338328
# Return either a Python integer or a 0-D array with some integer dtype
339329
idx = draw(python_integer_indices(sizes))
340-
dtype = draw(integer_dtypes)
330+
dtype = draw(xps.integer_dtypes() | xps.unsigned_integer_dtypes())
341331
m, M = dh.dtype_ranges[dtype]
342332
if m <= idx <= M:
343333
return draw(one_of(just(idx),

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,20 @@ def run(n, d, data):
128128
assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results)
129129

130130

131-
132131
@given(finite=st.booleans(), dtype=xps.floating_dtypes(), data=st.data())
133132
def test_symmetric_matrices(finite, dtype, data):
134-
m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite))
133+
m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite), label="m")
135134
assert m.dtype == dtype
136135
# TODO: This part of this test should be part of the .mT test
137136
ah.assert_exactly_equal(m, m.mT)
138137

139138
if finite:
140139
ah.assert_finite(m)
141140

142-
@given(m=hh.positive_definite_matrices(hh.shared_floating_dtypes),
143-
dtype=hh.shared_floating_dtypes)
144-
def test_positive_definite_matrices(m, dtype):
141+
142+
@given(dtype=xps.floating_dtypes(), data=st.data())
143+
def test_positive_definite_matrices(dtype, data):
144+
m = data.draw(hh.positive_definite_matrices(st.just(dtype)), label="m")
145145
assert m.dtype == dtype
146146
# TODO: Test that it actually is positive definite
147147

array_api_tests/test_creation_functions.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -427,21 +427,11 @@ def test_full(shape, fill_value, kw):
427427
ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value))
428428

429429

430-
@st.composite
431-
def full_like_fill_values(draw):
432-
kw = draw(
433-
st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_like_kw")
434-
)
435-
dtype = kw.get("dtype", None) or draw(hh.shared_dtypes)
436-
return draw(xps.from_dtype(dtype))
437-
438-
439-
@given(
440-
x=xps.arrays(dtype=hh.shared_dtypes, shape=hh.shapes()),
441-
fill_value=full_like_fill_values(),
442-
kw=st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_like_kw"),
443-
)
444-
def test_full_like(x, fill_value, kw):
430+
@given(kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), data=st.data())
431+
def test_full_like(kw, data):
432+
dtype = kw.get("dtype", None) or data.draw(xps.scalar_dtypes(), label="dtype")
433+
x = data.draw(xps.arrays(dtype=dtype, shape=hh.shapes()), label="x")
434+
fill_value = data.draw(xps.from_dtype(dtype), label="fill_value")
445435
out = xp.full_like(x, fill_value, **kw)
446436
dtype = kw.get("dtype", None) or x.dtype
447437
if kw.get("dtype", None) is None:
@@ -551,7 +541,7 @@ def test_ones(shape, kw):
551541

552542

553543
@given(
554-
x=xps.arrays(dtype=hh.dtypes, shape=hh.shapes()),
544+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
555545
kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()),
556546
)
557547
def test_ones_like(x, kw):
@@ -589,7 +579,7 @@ def test_zeros(shape, kw):
589579

590580

591581
@given(
592-
x=xps.arrays(dtype=hh.dtypes, shape=hh.shapes()),
582+
x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()),
593583
kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()),
594584
)
595585
def test_zeros_like(x, kw):

0 commit comments

Comments
 (0)