
Commit 9b88e92

feat: support for many padding dynamo converters (#2482)
1 parent 4f8eb56 commit 9b88e92

File tree

5 files changed: +574 −6 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 122 additions & 0 deletions
@@ -2304,3 +2304,125 @@ def aten_ops_addmm(
         beta=kwargs.get("beta", 1),
         alpha=kwargs.get("alpha", 1),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.constant_pad_nd.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_constant_pad(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pad.constant_padNd(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args_bounds_check(args, 2, 0),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad1d.default)
+@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad2d.default)
+@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad3d.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_reflection_pad(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pad.reflection_padNd(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.replication_pad1d.default)
+@dynamo_tensorrt_converter(torch.ops.aten.replication_pad2d.default)
+@dynamo_tensorrt_converter(torch.ops.aten.replication_pad3d.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_replication_pad(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pad.replication_padNd(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten._pad_circular.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_circular_pad(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pad.circular_padNd(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.pad.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_pad(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pad.pad(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        pad=args[1],
+        mode=args_bounds_check(args, 2, "constant"),
+        value=args_bounds_check(args, 3, None),
+    )
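
A minimal usage sketch (not part of the diff): with these converters registered, an aten padding op such as torch.nn.functional.pad should lower through the Torch-TensorRT dynamo frontend instead of falling back to PyTorch. The module name and input shapes below are illustrative assumptions, and a CUDA device with TensorRT installed is assumed.

    import torch
    import torch_tensorrt

    class PadModule(torch.nn.Module):  # hypothetical example module
        def forward(self, x):
            # lowers to aten.pad / aten.constant_pad_nd and hits the converters above
            return torch.nn.functional.pad(x, (1, 1, 2, 2), mode="constant", value=0.0)

    model = PadModule().eval().cuda()
    inputs = [torch.randn(1, 3, 32, 32, device="cuda")]
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
    print(trt_model(*inputs).shape)  # expected: torch.Size([1, 3, 36, 34])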

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
     linear,
     matmul,
     normalization,
+    pad,
     permutation,
     pool,
     reduce,

Lines changed: 5 additions & 6 deletions
@@ -1,17 +1,16 @@
-from typing import Dict, Optional, Sequence, Union
+from typing import Optional, Sequence, Union
 
 import numpy as np
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
-    SourceIR,
     get_positive_dim,
     get_trt_tensor,
 )
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.fx.types import TRTTensor
 
 
 def cat(
@@ -23,12 +22,12 @@ def cat(
     dim: int,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     trt_inputs = []
-    for each_input in input:
+    for i, each_input in enumerate(input):
         if not isinstance(each_input, TRTTensor):
-            each_input = get_trt_tensor(ctx, each_input, name + "_tensor_{i}")
+            each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
         trt_inputs.append(each_input)
     concat_layer = ctx.net.add_concatenation(trt_inputs)
     dim = get_positive_dim(dim, len(input[0].shape))
     concat_layer.axis = dim
-    set_layer_name(concat_layer, target, name + "_gather", source_ir)
+    set_layer_name(concat_layer, target, f"{name}_gather", source_ir)
     return concat_layer.get_output(0)
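
The fix in this hunk is worth calling out: the original code passed the literal string name + "_tensor_{i}" (no f-string), so every non-TRT input to cat received the same tensor name. A quick illustration of the difference, using a made-up name:

    name, i = "cat_op", 1
    print(name + "_tensor_{i}")   # 'cat_op_tensor_{i}'  -- placeholder never filled in
    print(f"{name}_tensor_{i}")   # 'cat_op_tensor_1'    -- unique per input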
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
+from typing import Optional, Sequence, Union
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import (
+    has_dynamic_shape,
+    set_layer_name,
+)
+from torch_tensorrt.fx.types import TRTTensor
+
+"""
+Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
+Use ISliceLayer to pad the tensor, which supports new non-constant, reflects padding
+mode and clamp, and supports padding output with dynamic shape.
+"""
+
+
+def constant_padNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    pad: Sequence[int],
+    value: Union[int, float] = 0,
+) -> TRTTensor:
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
+
+    rank = len(input.shape)
+
+    if len(pad) // 2 > rank:
+        raise RuntimeError(
+            f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
+        )
+
+    start_list = [0] * rank
+    new_shape = list(input.shape)
+
+    for i in range(0, len(pad) // 2):
+        start_list[-i - 1] = -pad[i * 2]
+        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+
+    stride_list = [1] * rank
+    layer = ctx.net.add_slice(
+        input,
+        start=tuple(start_list),
+        shape=tuple(new_shape),
+        stride=tuple(stride_list),
+    )
+    value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
+    layer.set_input(4, value_const)
+    layer.mode = trt.SliceMode.FILL
+
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
+def reflection_padNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    padding: Sequence[int],
+) -> TRTTensor:
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
+
+    rank = len(input.shape)
+
+    if len(padding) // 2 > rank:
+        raise RuntimeError(
+            f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
+        )
+
+    start_list = [0] * rank
+    new_shape = list(input.shape)
+
+    for i in range(0, len(padding) // 2):
+        start_list[-i - 1] = -padding[i * 2]
+        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+
+    stride_list = [1] * rank
+    layer = ctx.net.add_slice(
+        input,
+        start=tuple(start_list),
+        shape=tuple(new_shape),
+        stride=tuple(stride_list),
+    )
+    layer.mode = trt.SliceMode.REFLECT
+
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
+def replication_padNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    padding: Sequence[int],
+) -> TRTTensor:
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
+
+    rank = len(input.shape)
+
+    if len(padding) // 2 > rank:
+        raise RuntimeError(
+            f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
+        )
+
+    start_list = [0] * rank
+    new_shape = list(input.shape)
+
+    for i in range(0, len(padding) // 2):
+        start_list[-i - 1] = -padding[i * 2]
+        new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
+
+    stride_list = [1] * rank
+    layer = ctx.net.add_slice(
+        input,
+        start=tuple(start_list),
+        shape=tuple(new_shape),
+        stride=tuple(stride_list),
+    )
+    layer.mode = trt.SliceMode.CLAMP
+
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
+def circular_padNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    pad: Sequence[int],
+) -> TRTTensor:
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
+
+    rank = len(input.shape)
+
+    if len(pad) // 2 > rank:
+        raise RuntimeError(
+            f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
+        )
+
+    start_list = [0] * rank
+    new_shape = list(input.shape)
+
+    for i in range(0, len(pad) // 2):
+        start_list[-i - 1] = -pad[i * 2]
+        new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
+
+    stride_list = [1] * rank
+    layer = ctx.net.add_slice(
+        input,
+        start=tuple(start_list),
+        shape=tuple(new_shape),
+        stride=tuple(stride_list),
+    )
+    layer.mode = trt.SliceMode.WRAP
+
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
+
+
+def pad(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    pad: Sequence[int],
+    mode: str = "constant",
+    value: Optional[float] = None,
+) -> TRTTensor:
+    if mode == "constant":
+        return constant_padNd(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_{mode}",
+            input,
+            pad,
+            value if value is not None else 0,
+        )
+    elif mode == "reflect":
+        return reflection_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
+    elif mode == "replicate":
+        return replication_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
+    elif mode == "circular":
+        return circular_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
+    else:
+        raise RuntimeError(
+            f'We currently only support for `mode` in ["constant", "reflect", "replicate", "circular"], but got {mode}'
+        )
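
All four helpers share the same start/shape arithmetic and differ only in the slice mode (constant -> FILL, reflect -> REFLECT, replicate -> CLAMP, circular -> WRAP), per the ISliceLayer note at the top of the file. A small standalone sketch of that arithmetic (plain Python, no TensorRT required; slice_params is a hypothetical helper, not part of the commit):

    def slice_params(shape, pad):
        # mirrors the loop in constant_padNd: a negative slice start pads the
        # front of a dimension, and the output shape grows by both pad amounts
        rank = len(shape)
        start = [0] * rank
        new_shape = list(shape)
        for i in range(len(pad) // 2):
            start[-i - 1] = -pad[2 * i]
            new_shape[-i - 1] += pad[2 * i] + pad[2 * i + 1]
        return start, new_shape

    # pad=(1, 1, 2, 2): last dim padded by (1, 1), next-to-last dim by (2, 2)
    print(slice_params([1, 3, 32, 32], [1, 1, 2, 2]))
    # ([0, 0, -2, -1], [1, 3, 36, 34])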
