Skip to content

Commit 85e28ac

Browse files
committed
feat: support pad dynamo converter
1 parent 49ebcc5 commit 85e28ac

File tree

3 files changed

+79
-1
lines changed

3 files changed

+79
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2113,3 +2113,28 @@ def aten_ops_circular_pad(
21132113
args[0],
21142114
args[1],
21152115
)
2116+
2117+
2118+
@dynamo_tensorrt_converter(torch.ops.aten.pad.default)
2119+
@enforce_tensor_types(
2120+
{
2121+
0: (TRTTensor,),
2122+
}
2123+
)
2124+
def aten_ops_pad(
2125+
ctx: ConversionContext,
2126+
target: Target,
2127+
args: Tuple[Argument, ...],
2128+
kwargs: Dict[str, Argument],
2129+
name: str,
2130+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2131+
return impl.pad.pad(
2132+
ctx,
2133+
target,
2134+
SourceIR.ATEN,
2135+
name,
2136+
args[0],
2137+
pad=args[1],
2138+
mode=args_bounds_check(args, 2, "constant"),
2139+
value=args_bounds_check(args, 3, None),
2140+
)

py/torch_tensorrt/dynamo/conversion/impl/pad.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def constant_padNd(
2020
name: str,
2121
input: TRTTensor,
2222
pad: Sequence[int],
23-
value: int = 0,
23+
value: Union[int, float] = 0,
2424
) -> TRTTensor:
2525
if has_dynamic_shape(input.shape):
2626
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
@@ -269,3 +269,35 @@ def circular_padNd(
269269
raise RuntimeError(
270270
f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D"
271271
)
272+
273+
274+
def pad(
275+
ctx: ConversionContext,
276+
target: Union[Target, str],
277+
source_ir: Optional[SourceIR],
278+
name: str,
279+
input: TRTTensor,
280+
pad: Sequence[int],
281+
mode: str = "constant",
282+
value: Optional[float] = None,
283+
) -> TRTTensor:
284+
if mode == "constant":
285+
return constant_padNd(
286+
ctx,
287+
target,
288+
source_ir,
289+
f"{name}_{mode}",
290+
input,
291+
pad,
292+
value if value is not None else 0,
293+
)
294+
elif mode == "reflect":
295+
return reflection_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
296+
elif mode == "replicate":
297+
return replication_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
298+
elif mode == "circular":
299+
return circular_padNd(ctx, target, source_ir, f"{name}_{mode}", input, pad)
300+
else:
301+
raise RuntimeError(
302+
f'We currently only support for `mode` in ["constant", "reflect", "replicate", "circular"], but got {mode}'
303+
)

tests/py/dynamo/conversion/test_pad_aten.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,5 +216,26 @@ def forward(self, input):
216216
)
217217

218218

219+
class TestPadConverter(DispatchTestCase):
220+
@parameterized.expand(
221+
[
222+
((3, 3), (2, 2), "constant"),
223+
((2, 2, 4), (2, 3, 1, 0), "reflect"),
224+
((1, 2, 3, 4), (3, 2, 2, 1, 1, 1), "replicate"),
225+
((2, 3, 4, 5), (3, 2, 1, 0), "circular"),
226+
]
227+
)
228+
def test_pad(self, shape, pad, mode, value=None):
229+
class TestModule(torch.nn.Module):
230+
def forward(self, input):
231+
return torch.ops.aten.pad.default(input, pad, mode, value)
232+
233+
input = [torch.randn(shape)]
234+
self.run_test(
235+
TestModule(),
236+
input,
237+
)
238+
239+
219240
if __name__ == "__main__":
220241
run_tests()

0 commit comments

Comments
 (0)