Skip to content

Commit f4fe98b

Browse files
committed
Fix python lint issue
1 parent 74f9f95 commit f4fe98b

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,12 @@ def acc_ops_pad_with_slice_layer(
413413
)
414414

415415
shape = tuple(
416-
input_shape[i] + (pad[-(i - prefix_len) * 2 - 1] + pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0)
416+
input_shape[i]
417+
+ (
418+
pad[-(i - prefix_len) * 2 - 1] + pad[-(i - prefix_len) * 2 - 2]
419+
if i >= prefix_len
420+
else 0
421+
)
417422
for i in range(0, len(input_shape))
418423
)
419424
stride = tuple([1] * len(shape))

py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,20 @@ class TestFusePermuteMatmul(AccTestCase):
3737
lambda x: x.permute(0, 1, 3, 2),
3838
torch.matmul,
3939
),
40-
param("transpose_lhs_bmm_broadcast", (3, 2), (3, 3, 4), tranpose_last_two_dims, op=torch.matmul),
41-
param("transpose_rhs_bmm_broadcast", (3, 3, 4), (3, 4), rhs_op=tranpose_last_two_dims, op=torch.matmul),
40+
param(
41+
"transpose_lhs_bmm_broadcast",
42+
(3, 2),
43+
(3, 3, 4),
44+
tranpose_last_two_dims,
45+
op=torch.matmul,
46+
),
47+
param(
48+
"transpose_rhs_bmm_broadcast",
49+
(3, 3, 4),
50+
(3, 4),
51+
rhs_op=tranpose_last_two_dims,
52+
op=torch.matmul,
53+
),
4254
]
4355
)
4456
def test_fuse_permute_matmul(

0 commit comments

Comments
 (0)