Skip to content

Commit a8993d6

Browse files
Fix _general_dot doctest
1 parent 5ba3cc7 commit a8993d6

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

pytensor/tensor/einsum.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,16 @@ def _general_dot(
214214
215215
import pytensor.tensor as pt
216216
from pytensor.tensor.einsum import _general_dot
217+
import numpy as np
218+
217219
A = pt.tensor(shape = (3, 4, 5))
218220
B = pt.tensor(shape = (3, 5, 2))
219221
220222
result = _general_dot((A, B), axes=[[2], [1]], batch_axes=[[0], [0]])
221-
print(result.type.shape)
223+
224+
A_val = np.empty((3, 4, 5))
225+
B_val = np.empty((3, 5, 2))
226+
print(result.shape.eval({A:A_val, B:B_val}))
222227
223228
.. testoutput::
224229

0 commit comments

Comments
 (0)