Skip to content

Commit f94e02c

Browse files
jessegrabowski authored and ricardoV94 committed
Add float32 test support
1 parent 73e10e0 commit f94e02c

File tree

1 file changed

+24
-12
lines changed

1 file changed

+24
-12
lines changed

tests/tensor/test_einsum.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
from pytensor.tensor.shape import Reshape
1515

1616

17+
floatX = pytensor.config.floatX
18+
ATOL = RTOL = 1e-8 if floatX == "float64" else 1e-4
19+
20+
1721
def assert_no_blockwise_in_graph(fgraph: FunctionGraph, core_op=None) -> None:
1822
for node in fgraph.apply_nodes:
1923
if isinstance(node.op, Blockwise):
@@ -79,11 +83,10 @@ def test_general_dot():
7983
np_batched_tensordot = np.vectorize(
8084
partial(np.tensordot, axes=tensordot_axes), signature=signature
8185
)
82-
x_test = rng.normal(size=x.type.shape)
83-
y_test = rng.normal(size=y.type.shape)
86+
x_test = rng.normal(size=x.type.shape).astype(floatX)
87+
y_test = rng.normal(size=y.type.shape).astype(floatX)
8488
np.testing.assert_allclose(
85-
fn(x_test, y_test),
86-
np_batched_tensordot(x_test, y_test),
89+
fn(x_test, y_test), np_batched_tensordot(x_test, y_test), atol=ATOL, rtol=RTOL
8790
)
8891

8992

@@ -130,7 +133,7 @@ def test_einsum_signatures(static_shape_known, signature):
130133
assert out.owner.op.optimize == static_shape_known or len(operands) <= 2
131134

132135
rng = np.random.default_rng(37)
133-
test_values = [rng.normal(size=shape) for shape in shapes]
136+
test_values = [rng.normal(size=shape).astype(floatX) for shape in shapes]
134137
np_out = np.einsum(signature, *test_values)
135138

136139
fn = function(operands, out)
@@ -139,7 +142,7 @@ def test_einsum_signatures(static_shape_known, signature):
139142
# print(); fn.dprint(print_type=True)
140143

141144
assert_no_blockwise_in_graph(fn.maker.fgraph)
142-
np.testing.assert_allclose(pt_out, np_out)
145+
np.testing.assert_allclose(pt_out, np_out, atol=ATOL, rtol=RTOL)
143146

144147

145148
def test_batch_dim():
@@ -165,40 +168,49 @@ def test_einsum_conv():
165168
conv_signature = "bchwkt,fckt->bfhw"
166169
windowed_input = rng.random(
167170
size=(batch_size, channels, height, width, kernel_size, kernel_size)
171+
).astype(floatX)
172+
weights = rng.random(size=(num_filters, channels, kernel_size, kernel_size)).astype(
173+
floatX
168174
)
169-
weights = rng.random(size=(num_filters, channels, kernel_size, kernel_size))
170175
result = einsum(conv_signature, windowed_input, weights).eval()
171176

172177
assert result.shape == (32, 15, 8, 8)
173178
np.testing.assert_allclose(
174179
result,
175180
np.einsum("bchwkt,fckt->bfhw", windowed_input, weights),
181+
atol=ATOL,
182+
rtol=RTOL,
176183
)
177184

178185

179186
def test_ellipsis():
180187
rng = np.random.default_rng(159)
181188
x = pt.tensor("x", shape=(3, 5, 7, 11))
182189
y = pt.tensor("y", shape=(3, 5, 11, 13))
183-
x_test = rng.normal(size=x.type.shape)
184-
y_test = rng.normal(size=y.type.shape)
190+
x_test = rng.normal(size=x.type.shape).astype(floatX)
191+
y_test = rng.normal(size=y.type.shape).astype(floatX)
185192
expected_out = np.matmul(x_test, y_test)
186193

187194
with pytest.raises(ValueError):
188195
pt.einsum("mp,pn->mn", x, y)
189196

190197
out = pt.einsum("...mp,...pn->...mn", x, y)
191-
np.testing.assert_allclose(out.eval({x: x_test, y: y_test}), expected_out)
198+
np.testing.assert_allclose(
199+
out.eval({x: x_test, y: y_test}), expected_out, atol=ATOL, rtol=RTOL
200+
)
192201

193202
# Put batch axes in the middle
194203
new_x = pt.moveaxis(x, -2, 0)
195204
new_y = pt.moveaxis(y, -2, 0)
196205
out = pt.einsum("m...p,p...n->m...n", new_x, new_y)
197206
np.testing.assert_allclose(
198-
out.eval({x: x_test, y: y_test}), expected_out.transpose(-2, 0, 1, -1)
207+
out.eval({x: x_test, y: y_test}),
208+
expected_out.transpose(-2, 0, 1, -1),
209+
atol=ATOL,
210+
rtol=RTOL,
199211
)
200212

201213
out = pt.einsum("m...p,p...n->mn", new_x, new_y)
202214
np.testing.assert_allclose(
203-
out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1))
215+
out.eval({x: x_test, y: y_test}), expected_out.sum((0, 1)), atol=ATOL, rtol=RTOL
204216
)

0 commit comments

Comments (0)