Skip to content

Commit 69d371f

Browse files
author
XingFei Xi
committed
broadcast the two input shapes for transposed matmul
modified: py/torch_tensorrt/fx/converters/acc_ops_converters.py modified: py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py
1 parent 158be87 commit 69d371f

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def trt_transposed_matmul_converter(network, target, args, kwargs, name):
3737
lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
3838
if isinstance(rhs, torch.nn.Parameter):
3939
rhs = get_trt_tensor(network, rhs, f"{name}_rhs")
40+
41+
lhs, rhs = broadcast(
42+
network,
43+
lhs,
44+
rhs,
45+
f"{lhs.name}_broadcast",
46+
f"{rhs.name}_broadcast",
47+
)
4048
layer = network.add_matrix_multiply(
4149
lhs,
4250
trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,

py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class TestFusePermuteMatmul(AccTestCase):
3737
lambda x: x.permute(0, 1, 3, 2),
3838
torch.matmul,
3939
),
40+
param("transpose_lhs_bmm_broadcast", (3, 2), (3, 3, 4), tranpose_last_two_dims, op=torch.matmul),
41+
param("transpose_rhs_bmm_broadcast", (3, 3, 4), (3, 4), rhs_op=tranpose_last_two_dims, op=torch.matmul),
4042
]
4143
)
4244
def test_fuse_permute_matmul(
@@ -58,6 +60,7 @@ def forward(self, x, y):
5860
inputs,
5961
{trt_transposed_matmul},
6062
apply_passes=[fuse_permute_matmul],
63+
test_implicit_batch_dim=(len(lhs_shape) == len(rhs_shape)),
6164
)
6265

6366
@parameterized.expand(

0 commit comments

Comments
 (0)