Skip to content

Commit 47cd634

Browse files
committed
Do not call tensordot when general_dot is just an elemwise multiplication
1 parent cbffae1 commit 47cd634

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

pytensor/tensor/einsum.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,13 @@ def _general_dot(
131131
# TODO: tensordot produces very complicated graphs unnecessarily
132132
# In some cases we are just doing elemwise addition after some transpositions
133133
# We also have some Blockwise(Reshape) that will slow down things!
134-
out = vectorize(
135-
partial(tensordot, axes=[core_lhs_axes, core_rhs_axes]), signature=signature
136-
)(lhs, rhs)
137-
138-
# # Reorder batch axes according to the original order of lhs
139-
# original_lhs_batch_axes, _ = batch_axes
140-
# final_batch_axes = tuple(np.argsort(original_lhs_batch_axes))
141-
# new_batch_axes = tuple(range(lhs_n_batch_axes))
142-
# out = moveaxis(out, new_batch_axes, final_batch_axes)
134+
if signature == "(),()->()":
135+
# Just a multiplication
136+
out = lhs * rhs
137+
else:
138+
out = vectorize(
139+
partial(tensordot, axes=[core_lhs_axes, core_rhs_axes]), signature=signature
140+
)(lhs, rhs)
143141

144142
return cast(TensorVariable, out)
145143

0 commit comments

Comments (0)