|
2 | 2 | from operator import mul
|
3 | 3 | from math import sqrt
|
4 | 4 | import itertools
|
5 |
| -from typing import Tuple, Optional, List |
| 5 | +from typing import Tuple, Optional, List, Sequence |
6 | 6 |
|
7 | 7 | from hypothesis import assume
|
8 | 8 | from hypothesis.strategies import (lists, integers, sampled_from,
|
@@ -69,19 +69,14 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
|
69 | 69 |
|
70 | 70 | promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
|
71 | 71 |
|
72 |
| -if FILTER_UNDEFINED_DTYPES: |
73 |
| - promotable_dtypes = [ |
74 |
| - (i, j) for i, j in promotable_dtypes |
75 |
| - if not isinstance(i, _UndefinedStub) |
76 |
| - and not isinstance(j, _UndefinedStub) |
77 |
| - ] |
78 |
| - |
79 |
| - |
80 | 72 | def mutually_promotable_dtypes(
|
81 | 73 | max_size: Optional[int] = 2,
|
82 | 74 | *,
|
83 |
| - dtypes: Tuple[DataType, ...] = dh.all_dtypes, |
| 75 | + dtypes: Sequence[DataType] = dh.all_dtypes, |
84 | 76 | ) -> SearchStrategy[Tuple[DataType, ...]]:
|
| 77 | + if FILTER_UNDEFINED_DTYPES: |
| 78 | + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] |
| 79 | + assert len(dtypes) > 0, "all dtypes undefined" # sanity check |
85 | 80 | if max_size == 2:
|
86 | 81 | return sampled_from(
|
87 | 82 | [(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
|
@@ -348,7 +343,7 @@ def multiaxis_indices(draw, shapes):
|
348 | 343 |
|
349 | 344 |
|
350 | 345 | def two_mutual_arrays(
|
351 |
| - dtypes: Tuple[DataType, ...] = dh.all_dtypes, |
| 346 | + dtypes: Sequence[DataType] = dh.all_dtypes, |
352 | 347 | two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
|
353 | 348 | ) -> SearchStrategy:
|
354 | 349 | mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
|
|
0 commit comments