@@ -81,7 +81,7 @@ def test_func_returns_array_with_correct_dtype(
81
81
x = data .draw (
82
82
xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes ).filter (x_filter ), label = 'x'
83
83
)
84
- arrays = [ x ]
84
+ out = func ( x )
85
85
else :
86
86
arrays = []
87
87
shapes = data .draw (
@@ -92,10 +92,10 @@ def test_func_returns_array_with_correct_dtype(
92
92
xps .arrays (dtype = dtype , shape = shape ).filter (x_filter ), label = f'x{ i } '
93
93
)
94
94
arrays .append (x )
95
- try :
96
- out = func (* arrays )
97
- except OverflowError :
98
- reject ()
95
+ try :
96
+ out = func (* arrays )
97
+ except OverflowError :
98
+ reject ()
99
99
assert out .dtype == out_dtype , f'{ out .dtype = !s} , but should be { out_dtype } '
100
100
101
101
@@ -147,23 +147,24 @@ def gen_op_params() -> Iterator[Tuple[str, Tuple[DT, ...], DT, Callable]]:
147
147
def test_operator_returns_array_with_correct_dtype (
148
148
expr , in_dtypes , out_dtype , x_filter , data
149
149
):
150
- locals_ = {}
151
150
if len (in_dtypes ) == 1 :
152
- locals_ [ 'x' ] = data .draw (
151
+ x = data .draw (
153
152
xps .arrays (dtype = in_dtypes [0 ], shape = hh .shapes ).filter (x_filter ), label = 'x'
154
153
)
154
+ out = eval (expr , {'x' : x })
155
155
else :
156
+ locals_ = {}
156
157
shapes = data .draw (
157
158
hh .mutually_broadcastable_shapes (len (in_dtypes )), label = 'shapes'
158
159
)
159
160
for i , (dtype , shape ) in enumerate (zip (in_dtypes , shapes ), 1 ):
160
161
locals_ [f'x{ i } ' ] = data .draw (
161
162
xps .arrays (dtype = dtype , shape = shape ).filter (x_filter ), label = f'x{ i } '
162
163
)
163
- try :
164
- out = eval (expr , locals_ )
165
- except OverflowError :
166
- reject ()
164
+ try :
165
+ out = eval (expr , locals_ )
166
+ except OverflowError :
167
+ reject ()
167
168
assert out .dtype == out_dtype , f'{ out .dtype = !s} , but should be { out_dtype } '
168
169
169
170
0 commit comments