
Commit 49ebcc5

feat: support circular padding dynamo converters for 1D, 2D, and 3D
1 parent 0c4cff0 commit 49ebcc5

File tree (3 files changed: +144 -0 lines changed)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
py/torch_tensorrt/dynamo/conversion/impl/pad.py
tests/py/dynamo/conversion/test_pad_aten.py


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
@@ -2090,3 +2090,26 @@ def aten_ops_replication_pad(
         args[0],
         args[1],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten._pad_circular.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_circular_pad(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.pad.circular_padNd(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
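Not part of the diff above, just orientation: this registration routes any aten._pad_circular node in a Dynamo-traced graph (which is what F.pad(..., mode="circular") lowers to) through impl.pad.circular_padNd. A hedged end-to-end sketch of exercising that path follows; the model is made up and the exact torch_tensorrt.compile arguments may differ between releases.

import torch
import torch_tensorrt  # illustrative usage only, not part of this commit

class CircularPadModel(torch.nn.Module):
    def forward(self, x):
        # F.pad with mode="circular" appears as aten._pad_circular in the Dynamo graph
        return torch.nn.functional.pad(x, (1, 1, 2, 2), mode="circular")

model = CircularPadModel().eval().cuda()
x = torch.randn(2, 2, 3, 3).cuda()

# Compile through the dynamo frontend; the new converter handles the pad node.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])
print(torch.allclose(model(x), trt_model(x), atol=1e-5))  # expected: True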

py/torch_tensorrt/dynamo/conversion/impl/pad.py

Lines changed: 59 additions & 0 deletions
@@ -210,3 +210,62 @@ def replication_padNd(
         raise RuntimeError(
             f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D"
         )
+
+
+def circular_padNd(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    pad: Sequence[int],
+) -> TRTTensor:
+    if has_dynamic_shape(input.shape):
+        assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."
+
+    padding_dims = len(pad) // 2
+
+    if padding_dims == 1 or padding_dims == 2 or padding_dims == 3:
+        for i in range(padding_dims):
+            dim = -1 - i
+            pre_pad, post_pad = pad[2 * i], pad[2 * i + 1]
+            pre_pad_tensor = impl.slice.slice_op(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_slice_pre{i}",
+                input,
+                dim=dim,
+                start=input.shape[dim] - pre_pad,
+                stop=input.shape[dim],
+                step=1,
+            )
+
+            post_pad_tensor = impl.slice.slice_op(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_slice_post{i}",
+                input,
+                dim=dim,
+                start=0,
+                stop=post_pad,
+                step=1,
+            )
+
+            output = impl.cat.cat(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_concat_dim{dim}",
+                input=(pre_pad_tensor, input, post_pad_tensor),
+                dim=dim,
+            )
+            input = output
+
+        return output
+
+    else:
+        raise RuntimeError(
+            f"We currently only support for padding 1D, 2D, and 3D, but got {padding_dims}D"
+        )
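Not part of the diff above: the TensorRT graph built by circular_padNd mirrors how circular padding is defined. Working from the last dimension inward, the last pre_pad elements of each axis are wrapped around to the front and the first post_pad elements to the back, and the concatenated result feeds the next dimension. Below is a minimal pure-PyTorch sketch of the same slice-and-concat construction, for illustration only; the helper name is made up and it assumes each pad amount does not exceed the corresponding dimension size.

import torch
from typing import Sequence

def circular_pad_reference(x: torch.Tensor, pad: Sequence[int]) -> torch.Tensor:
    # pad uses F.pad ordering: (last_dim_pre, last_dim_post, 2nd_last_pre, 2nd_last_post, ...)
    out = x
    for i in range(len(pad) // 2):
        dim = -1 - i
        pre, post = pad[2 * i], pad[2 * i + 1]
        front = out.narrow(dim, out.shape[dim] - pre, pre)  # tail wraps to the front
        back = out.narrow(dim, 0, post)                     # head wraps to the back
        out = torch.cat((front, out, back), dim=dim)
    return out

x = torch.arange(4.0).reshape(1, 1, 4)          # tensor([[[0., 1., 2., 3.]]])
print(circular_pad_reference(x, [1, 2]))         # tensor([[[3., 0., 1., 2., 3., 0., 1.]]])
print(torch.nn.functional.pad(x, (1, 2), mode="circular"))  # same values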

tests/py/dynamo/conversion/test_pad_aten.py

Lines changed: 62 additions & 0 deletions
@@ -154,5 +154,67 @@ def forward(self, input):
         )
 
 
+class TestCircularPadConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            # Per pytorch doc, the input should be 2D or 3D
+            ((3, 3), (1, 1)),
+            ((3, 3), (2, 2)),
+            ((2, 2, 2), (1, 1)),
+            ((2, 2, 4), (2, 3)),
+        ]
+    )
+    def test_circular_pad1d(self, shape, pad):
+        class TestModule(torch.nn.Module):
+            def forward(self, input):
+                return torch.ops.aten._pad_circular.default(input, pad)
+
+        input = [torch.randn(shape)]
+        self.run_test(
+            TestModule(),
+            input,
+        )
+
+    @parameterized.expand(
+        [
+            # Per pytorch doc, the input should be 3D or 4D
+            ((2, 2, 2), (1, 1, 1, 1)),
+            ((1, 2, 4), (2, 2, 1, 1)),
+            ((2, 2, 3, 3), (1, 1, 2, 2)),
+            ((2, 3, 4, 5), (4, 3, 0, 1)),
+        ]
+    )
+    def test_circular_pad2d(self, shape, pad):
+        class TestModule(torch.nn.Module):
+            def forward(self, input):
+                return torch.ops.aten._pad_circular.default(input, pad)
+
+        input = [torch.randn(shape)]
+        self.run_test(
+            TestModule(),
+            input,
+        )
+
+    @parameterized.expand(
+        [
+            # Per pytorch doc, the input should be 4D or 5D
+            ((2, 2, 2, 2), (1, 1, 1, 1, 1, 1)),
+            ((1, 2, 3, 4), (3, 2, 2, 1, 1, 1)),
+            ((2, 2, 3, 3, 4), (3, 3, 2, 1, 1, 2)),
+            ((2, 3, 4, 5, 6), (4, 3, 2, 1, 1, 0)),
+        ]
+    )
+    def test_circular_pad3d(self, shape, pad):
+        class TestModule(torch.nn.Module):
+            def forward(self, input):
+                return torch.ops.aten._pad_circular.default(input, pad)
+
+        input = [torch.randn(shape)]
+        self.run_test(
+            TestModule(),
+            input,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
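Not part of the diff above: the pad tuples in these tests follow the torch.nn.functional.pad convention, i.e. (before, after) pairs applied to the trailing dimensions first, so each case can be sanity-checked against eager circular padding. A small standalone example of that check, for illustration only and not part of the test file:

import torch
import torch.nn.functional as F

x = torch.randn(2, 2, 3, 3)
pad = (1, 1, 2, 2)  # (left, right, top, bottom): last dimension comes first
eager = F.pad(x, pad, mode="circular")
aten = torch.ops.aten._pad_circular.default(x, pad)
print(torch.equal(eager, aten))  # expected: True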
