|
2 | 2 | from functools import reduce
|
3 | 3 | from math import sqrt
|
4 | 4 | from operator import mul
|
5 |
| -from typing import Any, List, NamedTuple, Optional, Tuple |
| 5 | +from typing import Any, List, NamedTuple, Optional, Tuple, Sequence |
6 | 6 |
|
7 | 7 | from hypothesis import assume
|
8 | 8 | from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
|
@@ -68,19 +68,14 @@ def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
|
68 | 68 |
|
69 | 69 | promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(dh.promotion_table.keys(), key=_dtypes_sorter)
|
70 | 70 |
|
71 |
| -if FILTER_UNDEFINED_DTYPES: |
72 |
| - promotable_dtypes = [ |
73 |
| - (i, j) for i, j in promotable_dtypes |
74 |
| - if not isinstance(i, _UndefinedStub) |
75 |
| - and not isinstance(j, _UndefinedStub) |
76 |
| - ] |
77 |
| - |
78 |
| - |
79 | 71 | def mutually_promotable_dtypes(
|
80 | 72 | max_size: Optional[int] = 2,
|
81 | 73 | *,
|
82 |
| - dtypes: Tuple[DataType, ...] = dh.all_dtypes, |
| 74 | + dtypes: Sequence[DataType] = dh.all_dtypes, |
83 | 75 | ) -> SearchStrategy[Tuple[DataType, ...]]:
|
| 76 | + if FILTER_UNDEFINED_DTYPES: |
| 77 | + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] |
| 78 | + assert len(dtypes) > 0, "all dtypes undefined" # sanity check |
84 | 79 | if max_size == 2:
|
85 | 80 | return sampled_from(
|
86 | 81 | [(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
|
@@ -347,7 +342,7 @@ def multiaxis_indices(draw, shapes):
|
347 | 342 |
|
348 | 343 |
|
349 | 344 | def two_mutual_arrays(
|
350 |
| - dtypes: Tuple[DataType, ...] = dh.all_dtypes, |
| 345 | + dtypes: Sequence[DataType] = dh.all_dtypes, |
351 | 346 | two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
|
352 | 347 | ) -> SearchStrategy:
|
353 | 348 | mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
|
|
0 commit comments