
Commit 75b445b

broadcast the two input shapes for transposed matmul
1 parent 81f2dab commit 75b445b

File tree

2 files changed: +11 -0 lines changed


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 8 additions & 0 deletions
@@ -38,6 +38,14 @@ def trt_transposed_matmul_converter(network, target, args, kwargs, name):
         lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
     if isinstance(rhs, torch.nn.Parameter):
         rhs = get_trt_tensor(network, rhs, f"{name}_rhs")
+
+    lhs, rhs = broadcast(
+        network,
+        lhs,
+        rhs,
+        f"{lhs.name}_broadcast",
+        f"{rhs.name}_broadcast",
+    )
     layer = network.add_matrix_multiply(
         lhs,
         trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,
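The added broadcast() call aligns the ranks of the two matmul inputs before they reach add_matrix_multiply, since TensorRT's matrix-multiply layer expects both operands to have the same number of dimensions. A minimal sketch of the rank-alignment idea, assuming the fx broadcast helper pads the lower-rank input with leading size-1 dimensions (align_ranks below is a hypothetical stand-in, not the library helper):

```python
# Hypothetical helper (not torch_tensorrt's broadcast()) illustrating the
# rank alignment: pad the lower-rank shape with leading 1s so both matmul
# operands end up with the same number of dimensions.
def align_ranks(lhs_shape, rhs_shape):
    diff = len(lhs_shape) - len(rhs_shape)
    if diff > 0:
        rhs_shape = (1,) * diff + tuple(rhs_shape)
    elif diff < 0:
        lhs_shape = (1,) * -diff + tuple(lhs_shape)
    return tuple(lhs_shape), tuple(rhs_shape)

# e.g. a (2, 3) lhs against a batched (3, 3, 4) rhs -> (1, 2, 3) and (3, 3, 4)
print(align_ranks((2, 3), (3, 3, 4)))
```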

py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,8 @@ class TestFusePermuteMatmul(AccTestCase):
                 lambda x: x.permute(0, 1, 3, 2),
                 torch.matmul,
             ),
+            param("transpose_lhs_bmm_broadcast", (3, 2), (3, 3, 4), tranpose_last_two_dims, op=torch.matmul),
+            param("transpose_rhs_bmm_broadcast", (3, 3, 4), (3, 4), rhs_op=tranpose_last_two_dims, op=torch.matmul),
         ]
     )
     def test_fuse_permute_matmul(
@@ -58,6 +60,7 @@ def forward(self, x, y):
             inputs,
             {trt_transposed_matmul},
             apply_passes=[fuse_permute_matmul],
+            test_implicit_batch_dim=(len(lhs_shape) == len(rhs_shape)),
         )
 
     @parameterized.expand(
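The two new parameterized cases exercise matmuls whose operands have different ranks, so PyTorch's batched-matmul broadcasting applies; test_implicit_batch_dim is only enabled when the ranks match, presumably because implicit-batch mode cannot express this kind of broadcast. A small torch-only check of the shapes these cases produce (variable names here are illustrative, not from the test file):

```python
import torch

# transpose_lhs_bmm_broadcast: a (3, 2) lhs is transposed to (2, 3) and
# broadcast against a batched (3, 3, 4) rhs.
lhs = torch.randn(3, 2).transpose(-1, -2)   # (2, 3)
rhs = torch.randn(3, 3, 4)
assert torch.matmul(lhs, rhs).shape == (3, 2, 4)

# transpose_rhs_bmm_broadcast: a (3, 4) rhs is transposed to (4, 3) and
# broadcast against a batched (3, 3, 4) lhs.
lhs2 = torch.randn(3, 3, 4)
rhs2 = torch.randn(3, 4).transpose(-1, -2)  # (4, 3)
assert torch.matmul(lhs2, rhs2).shape == (3, 3, 3)
```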
