Skip to content

Commit 88fed13

Browse files
authored
make padding layer converter more efficient (#1470)
1 parent aa93a12 commit 88fed13

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -406,46 +406,34 @@ def acc_ops_pad_with_slice_layer(
406406
)
407407

408408
input_shape = input_val.shape
409-
pre_start = tuple(i - 1 for i in input_shape)
410409
prefix_len = len(input_shape) - len(pad) // 2
411-
pre_shape = tuple(
412-
input_shape[i] + (pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0)
410+
start = tuple(
411+
-pad[-(i - prefix_len) * 2 - 2] if i >= prefix_len else 0
413412
for i in range(0, len(input_shape))
414413
)
415-
pre_stride = [-1] * len(input_shape)
414+
415+
shape = tuple(
416+
input_shape[i]
417+
+ (
418+
pad[-(i - prefix_len) * 2 - 1] + pad[-(i - prefix_len) * 2 - 2]
419+
if i >= prefix_len
420+
else 0
421+
)
422+
for i in range(0, len(input_shape))
423+
)
424+
stride = tuple([1] * len(shape))
416425

417426
layer = network.add_slice(
418427
input_val,
419-
pre_start,
420-
pre_shape,
421-
pre_stride,
428+
start,
429+
shape,
430+
stride,
422431
)
423-
layer.set_input(4, value_const)
424-
layer.mode = trt.SliceMode.FILL
425-
set_layer_name(layer, target, f"pre_{name}")
426-
half_pad_output = layer.get_output(0)
427432

428-
shape = half_pad_output.shape
429-
mid_start = tuple(i - 1 for i in shape)
430-
mid_stride = [-1] * len(shape)
431-
layer = network.add_slice(half_pad_output, mid_start, shape, mid_stride)
432433
layer.set_input(4, value_const)
433434
layer.mode = trt.SliceMode.FILL
434-
set_layer_name(layer, target, f"transpose_{name}")
435-
transpose_output = layer.get_output(0)
436-
437-
shape = transpose_output.shape
438-
post_start = tuple([0] * len(shape))
439-
post_shape = tuple(
440-
shape[i] + (pad[-(i - prefix_len) * 2 - 1] if i >= prefix_len else 0)
441-
for i in range(0, len(shape))
442-
)
443-
post_stride = tuple([1] * len(shape))
435+
set_layer_name(layer, target, name)
444436

445-
layer = network.add_slice(transpose_output, post_start, post_shape, post_stride)
446-
layer.set_input(4, value_const)
447-
layer.mode = trt.SliceMode.FILL
448-
set_layer_name(layer, target, f"post_{name}")
449437
return layer.get_output(0)
450438

451439

py/torch_tensorrt/fx/test/passes/test_fuse_permute_matmul_trt.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,20 @@ class TestFusePermuteMatmul(AccTestCase):
3737
lambda x: x.permute(0, 1, 3, 2),
3838
torch.matmul,
3939
),
40-
param("transpose_lhs_bmm_broadcast", (3, 2), (3, 3, 4), tranpose_last_two_dims, op=torch.matmul),
41-
param("transpose_rhs_bmm_broadcast", (3, 3, 4), (3, 4), rhs_op=tranpose_last_two_dims, op=torch.matmul),
40+
param(
41+
"transpose_lhs_bmm_broadcast",
42+
(3, 2),
43+
(3, 3, 4),
44+
tranpose_last_two_dims,
45+
op=torch.matmul,
46+
),
47+
param(
48+
"transpose_rhs_bmm_broadcast",
49+
(3, 3, 4),
50+
(3, 4),
51+
rhs_op=tranpose_last_two_dims,
52+
op=torch.matmul,
53+
),
4254
]
4355
)
4456
def test_fuse_permute_matmul(

0 commit comments

Comments
 (0)