4
4
5
5
from hypothesis .strategies import (lists , integers , sampled_from ,
6
6
shared , floats , just , composite , one_of ,
7
- none , booleans , SearchStrategy )
7
+ none , booleans )
8
8
from hypothesis .extra .array_api import make_strategies_namespace
9
9
from hypothesis import assume
10
10
46
46
boolean_dtypes = boolean_dtypes .filter (lambda x : not isinstance (x , _UndefinedStub ))
47
47
dtypes = dtypes .filter (lambda x : not isinstance (x , _UndefinedStub ))
48
48
49
- shared_dtypes = shared (dtypes )
49
+ shared_dtypes = shared (dtypes , key = "dtype" )
50
50
51
- def make_dtype_pairs ():
51
+ # TODO: Importing things from test_type_promotion should be replaced by
52
+ # something that won't cause a circular import. Right now we use @st.composite
53
+ # only because it returns a lazy-evaluated strategy - in the future this method
54
+ # should remove the composite wrapper, just returning sampled_from(dtype_pairs)
55
+ # instead of drawing from it.
56
+ @composite
57
+ def mutually_promotable_dtype_pairs (draw , dtype_objects = dtype_objects ):
52
58
from .test_type_promotion import dtype_mapping , promotion_table
53
59
# sort for shrinking (sampled_from shrinks to the earlier elements in the
54
60
# list). Give pairs of the same dtypes first, then smaller dtypes,
@@ -66,19 +72,12 @@ def make_dtype_pairs():
66
72
dtype_pairs = [(i , j ) for i , j in dtype_pairs
67
73
if not isinstance (i , _UndefinedStub )
68
74
and not isinstance (j , _UndefinedStub )]
69
- return dtype_pairs
70
-
71
- def promotable_dtypes (dtype ):
72
- if isinstance (dtype , SearchStrategy ):
73
- return dtype .flatmap (promotable_dtypes )
74
- dtype_pairs = make_dtype_pairs ()
75
- dtypes = [j for i , j in dtype_pairs if i == dtype ]
76
- return sampled_from (dtypes )
77
-
78
- def mutually_promotable_dtype_pairs (dtype_objects = dtype_objects ):
79
- dtype_pairs = make_dtype_pairs ()
80
75
dtype_pairs = [(i , j ) for i , j in dtype_pairs if i in dtype_objects and j in dtype_objects ]
81
- return sampled_from (dtype_pairs )
76
+ return draw (sampled_from (dtype_pairs ))
77
+
78
+ shared_mutually_promotable_dtype_pairs = shared (
79
+ mutually_promotable_dtype_pairs (), key = "mutually_promotable_dtype_pair"
80
+ )
82
81
83
82
# shared() allows us to draw either the function or the function name and they
84
83
# will both correspond to the same function.
@@ -245,10 +244,10 @@ def multiaxis_indices(draw, shapes):
245
244
246
245
247
246
shared_arrays1 = xps .arrays (
248
- dtype = shared_dtypes ,
247
+ dtype = shared_mutually_promotable_dtype_pairs . map ( lambda pair : pair [ 0 ]) ,
249
248
shape = shared (two_mutually_broadcastable_shapes , key = "shape_pair" ).map (lambda pair : pair [0 ]),
250
249
)
251
250
shared_arrays2 = xps .arrays (
252
- dtype = promotable_dtypes ( shared_dtypes ),
251
+ dtype = shared_mutually_promotable_dtype_pairs . map ( lambda pair : pair [ 1 ] ),
253
252
shape = shared (two_mutually_broadcastable_shapes , key = "shape_pair" ).map (lambda pair : pair [1 ]),
254
253
)
0 commit comments