Skip to content

Commit 5c6ffa4

Browse files
author
XingFei Xi
committed
Broadcast the two input shapes for transposed matmul
Modified: py/torch_tensorrt/fx/converters/acc_ops_converters.py
Modified: py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py
1 parent 5d9ddaa commit 5c6ffa4

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def trt_transposed_matmul_converter(network, target, args, kwargs, name):
3737
lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
3838
if isinstance(rhs, torch.nn.Parameter):
3939
rhs = get_trt_tensor(network, rhs, f"{name}_rhs")
40+
41+
lhs, rhs = broadcast(
42+
network,
43+
lhs,
44+
rhs,
45+
f"{lhs.name}_broadcast",
46+
f"{rhs.name}_broadcast",
47+
)
4048
layer = network.add_matrix_multiply(
4149
lhs,
4250
trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,

py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def permute021(x):
2222
class TestFusePermuteMatmul(AccTestCase):
2323
@parameterized.expand(
2424
[
25+
("transpose_lhs_bmm_broadcast", (3, 2), (3, 3, 4), tranpose_last_two_dims),
2526
("transpose_lhs_bmm", (3, 3, 2), (3, 3, 4), tranpose_last_two_dims),
2627
param(
2728
"transpose_rhs_bmm", (3, 2, 3), (3, 4, 3), rhs_op=tranpose_last_two_dims
@@ -37,6 +38,8 @@ class TestFusePermuteMatmul(AccTestCase):
3738
lambda x: x.permute(0, 1, 3, 2),
3839
torch.matmul,
3940
),
41+
param("transpose_lhs_bmm_broadcast", (3, 2), (3, 3, 4), tranpose_last_two_dims, op=torch.matmul),
42+
param("transpose_rhs_bmm_broadcast", (3, 3, 4), (3, 4), rhs_op=tranpose_last_two_dims, op=torch.matmul),
4043
]
4144
)
4245
def test_fuse_permute_matmul(
@@ -53,11 +56,13 @@ def forward(self, x, y):
5356
return op(lhs_op(x), rhs_op(y))
5457

5558
inputs = [torch.randn(*lhs_shape), torch.randn(*rhs_shape)]
59+
print("!!!!!!!!! come")
5660
self.run_test(
5761
TestModule(),
5862
inputs,
5963
{trt_transposed_matmul},
6064
apply_passes=[fuse_permute_matmul],
65+
test_implicit_batch_dim=(len(lhs_shape) == len(rhs_shape)),
6166
)
6267

6368
@parameterized.expand(

0 commit comments

Comments
 (0)