Skip to content

Commit e5667c5

Browse files
committed
Filter undefined dtypes in mutually_promotable_dtypes
1 parent 797537e commit e5667c5

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from operator import mul
33
from math import sqrt
44
import itertools
5-
from typing import Tuple, Optional, List
5+
from typing import Tuple, Optional, List, Sequence
66

77
from hypothesis import assume
88
from hypothesis.strategies import (lists, integers, sampled_from,
@@ -69,19 +69,14 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
6969

7070
promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
7171

72-
if FILTER_UNDEFINED_DTYPES:
73-
promotable_dtypes = [
74-
(i, j) for i, j in promotable_dtypes
75-
if not isinstance(i, _UndefinedStub)
76-
and not isinstance(j, _UndefinedStub)
77-
]
78-
79-
8072
def mutually_promotable_dtypes(
8173
max_size: Optional[int] = 2,
8274
*,
83-
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
75+
dtypes: Sequence[DataType] = dh.all_dtypes,
8476
) -> SearchStrategy[Tuple[DataType, ...]]:
77+
if FILTER_UNDEFINED_DTYPES:
78+
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
79+
assert len(dtypes) > 0, "all dtypes undefined" # sanity check
8580
if max_size == 2:
8681
return sampled_from(
8782
[(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
@@ -348,7 +343,7 @@ def multiaxis_indices(draw, shapes):
348343

349344

350345
def two_mutual_arrays(
351-
dtypes: Tuple[DataType, ...] = dh.all_dtypes,
346+
dtypes: Sequence[DataType] = dh.all_dtypes,
352347
two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
353348
) -> SearchStrategy:
354349
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))

array_api_tests/meta/test_hypothesis_helpers.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,29 @@
1414
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
1515

1616
@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
17-
def test_mutually_promotable_dtypes(pairs):
18-
assert pairs in (
17+
def test_mutually_promotable_dtypes(pair):
18+
assert pair in (
1919
(xp.float32, xp.float32),
2020
(xp.float32, xp.float64),
2121
(xp.float64, xp.float32),
2222
(xp.float64, xp.float64),
2323
)
2424

2525

26+
@given(
27+
hh.mutually_promotable_dtypes(
28+
dtypes=[xp.uint8, _UndefinedStub("uint16"), xp.uint32]
29+
)
30+
)
31+
def test_partial_mutually_promotable_dtypes(pair):
32+
assert pair in (
33+
(xp.uint8, xp.uint8),
34+
(xp.uint8, xp.uint32),
35+
(xp.uint32, xp.uint8),
36+
(xp.uint32, xp.uint32),
37+
)
38+
39+
2640
def valid_shape(shape) -> bool:
2741
return (
2842
all(isinstance(side, int) for side in shape)

0 commit comments

Comments
 (0)