Skip to content

Commit bfa971a

Browse files
author
XingFei Xi
committed
broadcast the two input shapes for transposed matmul
1 parent a9a4bb2 commit bfa971a

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def trt_transposed_matmul_converter(network, target, args, kwargs, name):
3737
lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
3838
if isinstance(rhs, torch.nn.Parameter):
3939
rhs = get_trt_tensor(network, rhs, f"{name}_rhs")
40+
41+
lhs, rhs = broadcast(
42+
network,
43+
lhs,
44+
rhs,
45+
f"{lhs.name}_broadcast",
46+
f"{rhs.name}_broadcast",
47+
)
4048
layer = network.add_matrix_multiply(
4149
lhs,
4250
trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,

0 commit comments

Comments
 (0)