Skip to content

Commit 2b414ab

Browse files
committed
Remove promotable_dtypes and use the original mutual method
1 parent 03c634b commit 2b414ab

File tree

2 files changed

+17
-26
lines changed

2 files changed

+17
-26
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from hypothesis.strategies import (lists, integers, sampled_from,
66
shared, floats, just, composite, one_of,
7-
none, booleans, SearchStrategy)
7+
none, booleans)
88
from hypothesis.extra.array_api import make_strategies_namespace
99
from hypothesis import assume
1010

@@ -46,9 +46,15 @@
4646
boolean_dtypes = boolean_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
4747
dtypes = dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
4848

49-
shared_dtypes = shared(dtypes)
49+
shared_dtypes = shared(dtypes, key="dtype")
5050

51-
def make_dtype_pairs():
51+
# TODO: Importing things from test_type_promotion should be replaced by
52+
# something that won't cause a circular import. Right now we use @st.composite
53+
# only because it returns a lazy-evaluated strategy - in the future this method
54+
# should remove the composite wrapper, just returning sampled_from(dtype_pairs)
55+
# instead of drawing from it.
56+
@composite
57+
def mutually_promotable_dtype_pairs(draw, dtype_objects=dtype_objects):
5258
from .test_type_promotion import dtype_mapping, promotion_table
5359
# sort for shrinking (sampled_from shrinks to the earlier elements in the
5460
# list). Give pairs of the same dtypes first, then smaller dtypes,
@@ -66,19 +72,12 @@ def make_dtype_pairs():
6672
dtype_pairs = [(i, j) for i, j in dtype_pairs
6773
if not isinstance(i, _UndefinedStub)
6874
and not isinstance(j, _UndefinedStub)]
69-
return dtype_pairs
70-
71-
def promotable_dtypes(dtype):
72-
if isinstance(dtype, SearchStrategy):
73-
return dtype.flatmap(promotable_dtypes)
74-
dtype_pairs = make_dtype_pairs()
75-
dtypes = [j for i, j in dtype_pairs if i == dtype]
76-
return sampled_from(dtypes)
77-
78-
def mutually_promotable_dtype_pairs(dtype_objects=dtype_objects):
79-
dtype_pairs = make_dtype_pairs()
8075
dtype_pairs = [(i, j) for i, j in dtype_pairs if i in dtype_objects and j in dtype_objects]
81-
return sampled_from(dtype_pairs)
76+
return draw(sampled_from(dtype_pairs))
77+
78+
shared_mutually_promotable_dtype_pairs = shared(
79+
mutually_promotable_dtype_pairs(), key="mutually_promotable_dtype_pair"
80+
)
8281

8382
# shared() allows us to draw either the function or the function name and they
8483
# will both correspond to the same function.
@@ -245,10 +244,10 @@ def multiaxis_indices(draw, shapes):
245244

246245

247246
shared_arrays1 = xps.arrays(
248-
dtype=shared_dtypes,
247+
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[0]),
249248
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[0]),
250249
)
251250
shared_arrays2 = xps.arrays(
252-
dtype=promotable_dtypes(shared_dtypes),
251+
dtype=shared_mutually_promotable_dtype_pairs.map(lambda pair: pair[1]),
253252
shape=shared(two_mutually_broadcastable_shapes, key="shape_pair").map(lambda pair: pair[1]),
254253
)

array_api_tests/meta_tests/test_hypothesis_helpers.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,13 @@
88
from ..array_helpers import dtype_objects
99
from ..hypothesis_helpers import (MAX_ARRAY_SIZE,
1010
mutually_promotable_dtype_pairs,
11-
promotable_dtypes, shapes,
12-
two_broadcastable_shapes,
11+
shapes, two_broadcastable_shapes,
1312
two_mutually_broadcastable_shapes)
1413

1514
UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dtype_objects)
1615
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1716

1817

19-
@given(promotable_dtypes(xp.uint16))
20-
def test_promotable_dtypes(dtype):
21-
assert dtype in (
22-
xp.uint8, xp.uint16, xp.uint32, xp.uint64, xp.int8, xp.int16, xp.int32, xp.int64
23-
)
24-
25-
2618
@given(mutually_promotable_dtype_pairs([xp.float32, xp.float64]))
2719
def test_mutually_promotable_dtype_pairs(pairs):
2820
assert pairs in (

0 commit comments

Comments
 (0)