Skip to content

Commit 5578763

Browse files
authored
feat: support tile dynamo converter (#2402)
1 parent 59a4910 commit 5578763

File tree

3 files changed

+128
-2
lines changed

3 files changed

+128
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,30 @@ def aten_ops_cumsum(
723723
)
724724

725725

726-
@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
726+
@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
727+
@enforce_tensor_types(
728+
{
729+
0: (TRTTensor,),
730+
}
731+
) # type: ignore[misc]
732+
def aten_ops_tile(
733+
ctx: ConversionContext,
734+
target: Target,
735+
args: Tuple[Argument, ...],
736+
kwargs: Dict[str, Argument],
737+
name: str,
738+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
739+
return impl.slice.tile(
740+
ctx,
741+
target,
742+
SourceIR.ATEN,
743+
name,
744+
args[0],
745+
args[1],
746+
)
747+
748+
749+
@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
727750
@enforce_tensor_types(
728751
{
729752
0: (TRTTensor,),

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Optional
2+
from typing import Optional, Sequence
33

44
import numpy as np
55
import tensorrt as trt
@@ -203,3 +203,31 @@ def cumsum(
203203
set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
204204
loop_output.set_input(1, trip_limit)
205205
return loop_output.get_output(0)
206+
207+
208+
def tile(
209+
ctx: ConversionContext,
210+
target: Target,
211+
source_ir: Optional[SourceIR],
212+
name: str,
213+
input: TRTTensor,
214+
dims: Sequence[int],
215+
) -> TRTTensor:
216+
diff = len(dims) - len(input.shape)
217+
if diff > 0:
218+
# prepend 1 to input.shape
219+
new_shape = (1,) * diff + tuple(input.shape)
220+
input = impl.shuffle.reshape(
221+
ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape
222+
)
223+
elif diff < 0:
224+
# prepend 1 to dims
225+
dims = (1,) * -diff + tuple(dims)
226+
227+
shapes = [i * j for i, j in zip(input.shape, dims)]
228+
starts = [0] * len(dims)
229+
strides = [1] * len(dims)
230+
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
231+
layer.mode = trt.SampleMode.WRAP
232+
set_layer_name(layer, target, name)
233+
return layer.get_output(0)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestTileConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((3,), (1,)),
13+
((3,), (0,)),
14+
((3,), (2,)),
15+
((2,), (2, 2)),
16+
((2,), (0, 2)),
17+
]
18+
)
19+
def test_tile_1D(self, shape, dims):
20+
class Tile(nn.Module):
21+
def forward(self, x):
22+
return torch.ops.aten.tile.default(x, dims)
23+
24+
inputs = [torch.randn(shape)]
25+
self.run_test(
26+
Tile(),
27+
inputs,
28+
)
29+
30+
@parameterized.expand(
31+
[
32+
((3, 1), (0,)),
33+
((3, 1), (2,)),
34+
((2, 3), (2, 2)),
35+
((2, 3), (1, 0)),
36+
((2, 3), (0, 2)),
37+
((2, 3), (4, 2, 3)),
38+
((2, 3), (0, 0, 3)),
39+
((2, 3), (4, 2, 3, 1, 2)),
40+
]
41+
)
42+
def test_tile_2D(self, shape, dims):
43+
class Tile(nn.Module):
44+
def forward(self, x):
45+
return torch.ops.aten.tile.default(x, dims)
46+
47+
inputs = [torch.randn(shape)]
48+
self.run_test(
49+
Tile(),
50+
inputs,
51+
)
52+
53+
@parameterized.expand(
54+
[
55+
((4, 2, 3), (2,)),
56+
((4, 2, 3), (1, 2)),
57+
((1, 2, 3), (2, 3)),
58+
((1, 2, 3), (2, 3, 4)),
59+
((1, 2, 3), (2, 3, 4, 5)),
60+
]
61+
)
62+
def test_tile_3D(self, shape, dims):
63+
class Tile(nn.Module):
64+
def forward(self, x):
65+
return torch.ops.aten.tile.default(x, dims)
66+
67+
inputs = [torch.randn(shape)]
68+
self.run_test(
69+
Tile(),
70+
inputs,
71+
)
72+
73+
74+
if __name__ == "__main__":
75+
run_tests()

0 commit comments

Comments
 (0)