
Commit 7534fa2

committed
implement paddings via TRT ISliceLayer with different SliceMode
1 parent d908601 commit 7534fa2

File tree

2 files changed (+107 −183 lines):

py/torch_tensorrt/dynamo/conversion/impl/pad.py
py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

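The change drops the old concat-based implementations and instead emits a single TensorRT ISliceLayer per padding op, selecting the SliceMode that matches the PyTorch semantics: FILL for constant padding (with the fill value attached via set_input(4, ...)), REFLECT for reflection, CLAMP for replication, and WRAP for circular padding. A minimal sketch of the index arithmetic all four converters share (plain Python, no TensorRT required; the helper name pad_to_slice_args is illustrative only, not part of the diff):

from typing import List, Sequence, Tuple


def pad_to_slice_args(shape: Sequence[int], pad: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Map a PyTorch-style pad sequence (pre/post pairs, last dim first) to slice start/shape."""
    if len(pad) / 2 > len(shape):
        raise RuntimeError(
            f"Trying to pad last {len(pad) / 2} dimension but the input only has {len(shape)} dimension."
        )
    start = [0] * len(shape)
    new_shape = list(shape)
    for i in range(len(pad) // 2):
        start[-i - 1] = -pad[i * 2]                        # negative start reads "before" index 0
        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]   # grow the dim by pre + post padding
    return start, new_shape


# For a (1, 2, 3, 4) tensor and pad=(1, 1, 2, 0):
# start = [0, 0, -2, -1], shape = [1, 2, 5, 6]; the chosen SliceMode decides
# what values fill the out-of-bounds region.
print(pad_to_slice_args((1, 2, 3, 4), (1, 1, 2, 0)))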

py/torch_tensorrt/dynamo/conversion/impl/pad.py

Lines changed: 103 additions & 183 deletions
@@ -1,14 +1,22 @@
-import copy
 from typing import Optional, Sequence, Union
 
-import torch
-import torch_tensorrt.dynamo.conversion.impl as impl
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape
+from torch_tensorrt.fx.converters.converter_utils import (
+    get_trt_tensor,
+    has_dynamic_shape,
+    set_layer_name,
+)
 from torch_tensorrt.fx.types import TRTTensor
 
+"""
+Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
+Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding
+mode and clamp, and supports padding output with dynamic shape.
+"""
+
 
 def constant_padNd(
     ctx: ConversionContext,
@@ -19,43 +27,36 @@ def constant_padNd(
     pad: Sequence[int],
     value: Union[int, float] = 0,
 ) -> TRTTensor:
-    """
-    Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
-    Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding
-    mode and clamp, and supports padding output with dynamic shape.
-    """
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
 
-    # Implement constant padding via concat
-    curr_dim = len(input.shape) - 1
-
-    for i in range(0, len(pad), 2):
-        input_shape = list(input.shape)
-
-        pre_pad = pad[i]
-        post_pad = pad[i + 1]
-        pre_pad_shape = copy.deepcopy(input_shape)
-        pre_pad_shape[curr_dim] = pre_pad
-        pre_pad_tensor = torch.full(pre_pad_shape, float(value))
-        if pre_pad == post_pad:
-            post_pad_tensor = pre_pad_tensor
-        else:
-            post_pad_shape = copy.deepcopy(input_shape)
-            post_pad_shape[curr_dim] = post_pad
-            post_pad_tensor = torch.full(post_pad_shape, float(value))
-        output = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_concat{curr_dim}",
-            input=(pre_pad_tensor, input, post_pad_tensor),
-            dim=curr_dim,
+    rank = len(input.shape)
+
+    if len(pad) / 2 > rank:
+        raise RuntimeError(
+            f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
         )
-        curr_dim -= 1
-        input = output
 
-    return output
+    start_list = [0] * len(input.shape)
+    new_shape = input.shape
+
+    for i in range(0, len(pad) // 2):
+        start_list[-i - 1] = -pad[i * 2]
+        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+
+    stride_list = [1] * len(new_shape)
+    layer = ctx.net.add_slice(
+        input,
+        start=tuple(start_list),
+        shape=tuple(new_shape),
+        stride=tuple(stride_list),
+    )
+    value_const = get_trt_tensor(ctx.net, value, f"{name}_value", input.dtype)
+    layer.set_input(4, value_const)
+    layer.mode = trt.SliceMode.FILL
+
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
 
 
 def reflection_padNd(
@@ -69,53 +70,32 @@ def reflection_padNd(
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
 
-    padding_dims = len(padding) // 2
-
-    if padding_dims == 1 or padding_dims == 2 or padding_dims == 3:
-        for i in range(padding_dims):
-            dim = -1 - i
-            pre_pad, post_pad = padding[2 * i], padding[2 * i + 1]
-            pre_pad_tensor = impl.slice.slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_pre{i}",
-                input,
-                dim=dim,
-                start=pre_pad,
-                stop=0,
-                step=-1,
-            )
-
-            post_pad_tensor = impl.slice.slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_post{i}",
-                input,
-                dim=dim,
-                start=input.shape[dim] - 2,
-                stop=input.shape[dim] - post_pad - 2,
-                step=-1,
-            )
-
-            output = impl.cat.cat(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_concat_dim{dim}",
-                input=(pre_pad_tensor, input, post_pad_tensor),
-                dim=dim,
-            )
-            input = output
-
-        return output
+    rank = len(input.shape)
 
-    else:
+    if len(padding) / 2 > rank:
         raise RuntimeError(
-            f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D"
+            f"Trying to pad last {len(padding) / 2} dimension but the input only has {rank} dimension."
         )
 
+    start_list = [0] * len(input.shape)
+    new_shape = input.shape
+
+    for i in range(0, len(padding) // 2):
+        start_list[-i - 1] = -padding[i * 2]
+        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+
+    stride_list = [1] * len(new_shape)
+    layer = ctx.net.add_slice(
+        input,
+        start=tuple(start_list),
+        shape=tuple(new_shape),
+        stride=tuple(stride_list),
+    )
+    layer.mode = trt.SliceMode.REFLECT
+
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
 
 def replication_padNd(
     ctx: ConversionContext,
@@ -128,71 +108,32 @@ def replication_padNd(
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
 
-    padding_dims = len(padding) // 2
-
-    if padding_dims == 1 or padding_dims == 2 or padding_dims == 3:
-        for i in range(padding_dims):
-            dim = -1 - i
-            pre_pad, post_pad = padding[2 * i], padding[2 * i + 1]
-            pre_pad_tensor = impl.slice.slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_pre{i}",
-                input,
-                dim=dim,
-                start=0,
-                stop=1,
-                step=1,
-            )
-            new_shape = input.shape
-            new_shape[dim] = pre_pad
-            pre_pad_tensor = impl.slice.expand(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_expand_pre{i}",
-                pre_pad_tensor,
-                new_shape,
-            )
-
-            post_pad_tensor = impl.slice.slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_post{i}",
-                input,
-                dim=dim,
-                start=input.shape[dim] - 1,
-                stop=input.shape[dim],
-                step=1,
-            )
-            new_shape[dim] = post_pad
-            post_pad_tensor = impl.slice.expand(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_expand_post{i}",
-                post_pad_tensor,
-                new_shape,
-            )
-            output = impl.cat.cat(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_concat_dim{dim}",
-                input=(pre_pad_tensor, input, post_pad_tensor),
-                dim=dim,
-            )
-            input = output
-
-        return output
+    rank = len(input.shape)
 
-    else:
+    if len(padding) / 2 > rank:
         raise RuntimeError(
-            f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D"
+            f"Trying to pad last {len(padding) / 2} dimension but the input only has {rank} dimension."
         )
 
+    start_list = [0] * len(input.shape)
+    new_shape = input.shape
+
+    for i in range(0, len(padding) // 2):
+        start_list[-i - 1] = -padding[i * 2]
+        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+
+    stride_list = [1] * len(new_shape)
+    layer = ctx.net.add_slice(
+        input,
+        start=tuple(start_list),
+        shape=tuple(new_shape),
+        stride=tuple(stride_list),
+    )
+    layer.mode = trt.SliceMode.CLAMP
+
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
 
 def circular_padNd(
     ctx: ConversionContext,
@@ -205,53 +146,32 @@ def circular_padNd(
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
 
-    padding_dims = len(pad) // 2
-
-    if padding_dims == 1 or padding_dims == 2 or padding_dims == 3:
-        for i in range(padding_dims):
-            dim = -1 - i
-            pre_pad, post_pad = pad[2 * i], pad[2 * i + 1]
-            pre_pad_tensor = impl.slice.slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_pre{i}",
-                input,
-                dim=dim,
-                start=input.shape[dim] - pre_pad,
-                stop=input.shape[dim],
-                step=1,
-            )
-
-            post_pad_tensor = impl.slice.slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_post{i}",
-                input,
-                dim=dim,
-                start=0,
-                stop=post_pad,
-                step=1,
-            )
-
-            output = impl.cat.cat(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_concat_dim{dim}",
-                input=(pre_pad_tensor, input, post_pad_tensor),
-                dim=dim,
-            )
-            input = output
-
-        return output
+    rank = len(input.shape)
 
-    else:
+    if len(pad) / 2 > rank:
         raise RuntimeError(
-            f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D"
+            f"Trying to pad last {len(pad) / 2} dimension but the input only has {rank} dimension."
         )
 
+    start_list = [0] * len(input.shape)
+    new_shape = input.shape
+
+    for i in range(0, len(pad) // 2):
+        start_list[-i - 1] = -pad[i * 2]
+        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+
+    stride_list = [1] * len(new_shape)
+    layer = ctx.net.add_slice(
+        input,
+        start=tuple(start_list),
+        shape=tuple(new_shape),
+        stride=tuple(stride_list),
+    )
+    layer.mode = trt.SliceMode.WRAP
+
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
 
 def pad(
     ctx: ConversionContext,
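For intuition about the four SliceModes picked above, NumPy's pad modes reproduce the same behavior on a toy 1-D example (a NumPy analogy only, not the TensorRT code path):

import numpy as np

x = np.arange(1, 5)  # [1, 2, 3, 4]; pad one element on each side of the last dim

print(np.pad(x, (1, 1), mode="constant", constant_values=0))  # FILL    -> [0 1 2 3 4 0]
print(np.pad(x, (1, 1), mode="reflect"))                      # REFLECT -> [2 1 2 3 4 3]
print(np.pad(x, (1, 1), mode="edge"))                         # CLAMP   -> [1 1 2 3 4 4]
print(np.pad(x, (1, 1), mode="wrap"))                         # WRAP    -> [4 1 2 3 4 1]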

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@ def slice_op( # TODO: This should be slice not whatever is in base
     if stop is None:
         stop = input.shape[dim]
 
+    dim = get_positive_dim(dim, len(input.shape))
+    start = get_positive_dim(start, input.shape[dim])
+    stop = get_positive_dim(stop, input.shape[dim])
+
     if has_dynamic_shape(input.shape):
         # Check whether slice target dim is dynamic shape dim
         assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
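The slice_op hunk normalizes negative dim, start, and stop indices before the dynamic-shape checks run. A small sketch of the intended effect, assuming get_positive_dim wraps negative values the way Python indexing does (the normalize helper below is a stand-in, not the library function):

def normalize(value: int, size: int) -> int:
    # Python-style wrap-around for negative indices; non-negative values pass through.
    return value % size if value < 0 else value


shape = (8, 3, 32, 32)
dim, start, stop = -1, -4, -1           # slice the last dim starting at its 4th-to-last element
dim = normalize(dim, len(shape))        # -> 3
start = normalize(start, shape[dim])    # -> 28
stop = normalize(stop, shape[dim])      # -> 31
print(dim, start, stop)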
