Skip to content

Commit c92dd09

Browse files
committed
Filter undefined dtypes in mutually_promotable_dtypes
1 parent 0574111 commit c92dd09

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import reduce
33
from math import sqrt
44
from operator import mul
5-
from typing import Any, List, NamedTuple, Optional, Tuple
5+
from typing import Any, List, NamedTuple, Optional, Tuple, Sequence
66

77
from hypothesis import assume
88
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
@@ -68,19 +68,14 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
6868

6969
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
7070

71-
if FILTER_UNDEFINED_DTYPES:
72-
promotable_dtypes = [
73-
(i, j) for i, j in promotable_dtypes
74-
if not isinstance(i, _UndefinedStub)
75-
and not isinstance(j, _UndefinedStub)
76-
]
77-
78-
7971
def mutually_promotable_dtypes(
8072
max_size: Optional[int] = 2,
8173
*,
82-
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
74+
dtypes: Sequence[DataType] = dh.all_dtypes,
8375
) -> SearchStrategy[Tuple[DataType, ...]]:
76+
if FILTER_UNDEFINED_DTYPES:
77+
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
78+
assert len(dtypes) > 0, "all dtypes undefined" # sanity check
8479
if max_size == 2:
8580
return sampled_from(
8681
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
@@ -347,7 +342,7 @@ def multiaxis_indices(draw, shapes):
347342

348343

349344
def two_mutual_arrays(
350-
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
345+
dtypes: Sequence[DataType] = dh.all_dtypes,
351346
two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
352347
) -> SearchStrategy:
353348
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,29 @@
1515
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1616

1717
@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
18-
def test_mutually_promotable_dtypes(pairs):
19-
assert pairs in (
18+
def test_mutually_promotable_dtypes(pair):
19+
assert pair in (
2020
(xp.float32, xp.float32),
2121
(xp.float32, xp.float64),
2222
(xp.float64, xp.float32),
2323
(xp.float64, xp.float64),
2424
)
2525

2626

27+
@given(
28+
hh.mutually_promotable_dtypes(
29+
dtypes=[xp.uint8, _UndefinedStub("uint16"), xp.uint32]
30+
)
31+
)
32+
def test_partial_mutually_promotable_dtypes(pair):
33+
assert pair in (
34+
(xp.uint8, xp.uint8),
35+
(xp.uint8, xp.uint32),
36+
(xp.uint32, xp.uint8),
37+
(xp.uint32, xp.uint32),
38+
)
39+
40+
2741
def valid_shape(shape) -> bool:
2842
return (
2943
all(isinstance(side, int) for side in shape)

0 commit comments

Comments
 (0)