File tree Expand file tree Collapse file tree 1 file changed +7
-9
lines changed Expand file tree Collapse file tree 1 file changed +7
-9
lines changed Original file line number Diff line number Diff line change @@ -131,15 +131,13 @@ def _general_dot(
131
131
# TODO: tensordot produces very complicated graphs unnecessarily
132
132
# In some cases we are just doing elemwise addition after some transpositions
133
133
# We also have some Blockwise(Reshape) that will slow down things!
134
- out = vectorize (
135
- partial (tensordot , axes = [core_lhs_axes , core_rhs_axes ]), signature = signature
136
- )(lhs , rhs )
137
-
138
- # # Reorder batch axes according to the original order of lhs
139
- # original_lhs_batch_axes, _ = batch_axes
140
- # final_batch_axes = tuple(np.argsort(original_lhs_batch_axes))
141
- # new_batch_axes = tuple(range(lhs_n_batch_axes))
142
- # out = moveaxis(out, new_batch_axes, final_batch_axes)
134
+ if signature == "(),()->()" :
135
+ # Just a multiplication
136
+ out = lhs * rhs
137
+ else :
138
+ out = vectorize (
139
+ partial (tensordot , axes = [core_lhs_axes , core_rhs_axes ]), signature = signature
140
+ )(lhs , rhs )
143
141
144
142
return cast (TensorVariable , out )
145
143
You can’t perform that action at this time.
0 commit comments