@@ -14,6 +14,10 @@
 from pytensor.tensor.shape import Reshape
 
 
+floatX = pytensor.config.floatX
+ATOL = RTOL = 1e-8 if floatX == "float64" else 1e-4
+
+
 def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None:
     for node in fgraph.apply_nodes:
         if isinstance(node.op, Blockwise):
@@ -79,11 +83,10 @@ def test_general_dot():
     np_batched_tensordot = np.vectorize(
         partial(np.tensordot, axes=tensordot_axes), signature=signature
     )
-    x_test = rng.normal(size=x.type.shape)
-    y_test = rng.normal(size=y.type.shape)
+    x_test = rng.normal(size=x.type.shape).astype(floatX)
+    y_test = rng.normal(size=y.type.shape).astype(floatX)
     np.testing.assert_allclose(
-        fn(x_test, y_test),
-        np_batched_tensordot(x_test, y_test),
+        fn(x_test, y_test), np_batched_tensordot(x_test, y_test), atol=ATOL, rtol=RTOL
     )
 
 
@@ -130,7 +133,7 @@ def test_einsum_signatures(static_shape_known, signature):
     assert out.owner.op.optimize == static_shape_known or len(operands) <= 2
 
     rng = np.random.default_rng(37)
-    test_values = [rng.normal(size=shape) for shape in shapes]
+    test_values = [rng.normal(size=shape).astype(floatX) for shape in shapes]
     np_out = np.einsum(signature, *test_values)
 
     fn = function(operands, out)
@@ -139,7 +142,7 @@ def test_einsum_signatures(static_shape_known, signature):
     # print(); fn.dprint(print_type=True)
 
     assert_no_blockwise_in_graph(fn.maker.fgraph)
-    np.testing.assert_allclose(pt_out, np_out)
+    np.testing.assert_allclose(pt_out, np_out, atol=ATOL, rtol=RTOL)
 
 
 def test_batch_dim():
@@ -165,40 +168,49 @@ def test_einsum_conv():
     conv_signature = "bchwkt,fckt->bfhw"
     windowed_input = rng.random(
         size=(batch_size, channels, height, width, kernel_size, kernel_size)
+    ).astype(floatX)
+    weights = rng.random(size=(num_filters, channels, kernel_size, kernel_size)).astype(
+        floatX
     )
-    weights = rng.random(size=(num_filters, channels, kernel_size, kernel_size))
     result = einsum(conv_signature, windowed_input, weights).eval()
 
     assert result.shape == (32, 15, 8, 8)
     np.testing.assert_allclose(
         result,
         np.einsum("bchwkt,fckt->bfhw", windowed_input, weights),
+        atol=ATOL,
+        rtol=RTOL,
     )
 
 
 def test_ellipsis():
     rng = np.random.default_rng(159)
     x = pt.tensor("x", shape=(3, 5, 7, 11))
     y = pt.tensor("y", shape=(3, 5, 11, 13))
-    x_test = rng.normal(size=x.type.shape)
-    y_test = rng.normal(size=y.type.shape)
+    x_test = rng.normal(size=x.type.shape).astype(floatX)
+    y_test = rng.normal(size=y.type.shape).astype(floatX)
     expected_out = np.matmul(x_test, y_test)
 
     with pytest.raises(ValueError):
         pt.einsum("mp,pn->mn", x, y)
 
     out = pt.einsum("...mp,...pn->...mn", x, y)
-    np.testing.assert_allclose(out.eval({x: x_test, y: y_test}), expected_out)
+    np.testing.assert_allclose(
+        out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
+    )
 
     # Put batch axes in the middle
     new_x = pt.moveaxis(x, -2, 0)
     new_y = pt.moveaxis(y, -2, 0)
     out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
     np.testing.assert_allclose(
-        out.eval({x: x_test, y: y_test}), expected_out.transpose(-2, 0, 1, -1)
+        out.eval({x: x_test, y: y_test}),
+        expected_out.transpose(-2, 0, 1, -1),
+        atol=ATOL,
+        rtol=RTOL,
     )
 
     out = pt.einsum("m...p,p...n->mn", new_x, new_y)
     np.testing.assert_allclose(
-        out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1))
+        out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
     )
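
A minimal, self-contained sketch of the tolerance pattern this commit introduces, for readers skimming the diff: test data is cast to the configured float dtype, and the assert_allclose tolerances are widened when that dtype is float32. Plain NumPy stands in for PyTensor here, and the hard-coded floatX = "float32" is an assumption in place of reading pytensor.config.floatX; this is an illustration of the pattern, not the test file itself.

import numpy as np

floatX = "float32"  # assumption: stands in for pytensor.config.floatX
ATOL = RTOL = 1e-8 if floatX == "float64" else 1e-4

rng = np.random.default_rng(37)
x = rng.normal(size=(3, 5, 7, 11)).astype(floatX)
y = rng.normal(size=(3, 5, 11, 13)).astype(floatX)

# Two mathematically equivalent computations that accumulate rounding
# error differently: batched einsum vs. matmul.
out_einsum = np.einsum("...mp,...pn->...mn", x, y)
out_matmul = np.matmul(x, y)

# assert_allclose defaults to rtol=1e-7, which float32 arithmetic can
# miss; the dtype-aware tolerances keep the check stable for both dtypes.
np.testing.assert_allclose(out_einsum, out_matmul, atol=ATOL, rtol=RTOL)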