Skip to content

Commit 1ed9d6d

Browse files
committed
feat: support tile dynamo converter
1 parent 59a4910 commit 1ed9d6d

File tree

3 files changed

+128
-3
lines changed

3 files changed

+128
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,30 @@ def aten_ops_cumsum(
723723
)
724724

725725

726-
@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
726+
@dynamo_tensorrt_converter(torch.ops.aten.tile.default)  # type: ignore[misc]
@enforce_tensor_types({0: (TRTTensor,)})  # type: ignore[misc]
def aten_ops_tile(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.tile.default`` by delegating to the slice-based impl.

    args[0] is the input tensor (enforced to be a TRTTensor by the
    decorator above); args[1] is the per-dimension repeat-count sequence.
    """
    input_tensor, repeat_dims = args[0], args[1]
    return impl.slice.tile(
        ctx, target, SourceIR.ATEN, name, input_tensor, repeat_dims
    )
747+
748+
749+
@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
727750
@enforce_tensor_types(
728751
{
729752
0: (TRTTensor,),

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import math
2-
from typing import Optional
2+
from typing import Optional, Sequence
33

4-
import numpy as np
54
import tensorrt as trt
65
from torch.fx.node import Target
76
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -203,3 +202,31 @@ def cumsum(
203202
set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
204203
loop_output.set_input(1, trip_limit)
205204
return loop_output.get_output(0)
205+
206+
207+
def tile(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dims: Sequence[int],
) -> TRTTensor:
    """Repeat ``input`` along each dimension by the factors in ``dims``.

    Implements ``aten.tile`` semantics: when the lengths of ``dims`` and
    ``input.shape`` differ, the shorter one is left-padded with 1s before
    the element-wise repeat. The tiling itself is a single TensorRT slice
    layer in WRAP mode, whose oversized output reads the input cyclically.

    Args:
        ctx: Conversion context holding the TensorRT network being built.
        target: FX node target being converted (used for layer naming).
        source_ir: Originating IR (e.g. ATen) for layer-name provenance.
        name: Base name for the layers created here.
        input: Tensor to tile (assumed static shape here — dynamic dims
            would make ``i * j`` on ``input.shape`` invalid; TODO confirm
            dynamic-shape support is handled upstream).
        dims: Repeat count per dimension.

    Returns:
        The tiled output tensor.
    """
    diff = len(dims) - len(input.shape)
    if diff > 0:
        # More repeat factors than input dims: promote the input by
        # prepending size-1 dimensions (matches torch.tile behavior).
        new_shape = (1,) * diff + tuple(input.shape)
        input = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape
        )
    elif diff < 0:
        # Fewer repeat factors than input dims: leading dims repeat once.
        dims = (1,) * -diff + tuple(dims)

    shapes = [i * j for i, j in zip(input.shape, dims)]
    starts = [0] * len(dims)
    strides = [1] * len(dims)
    layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
    # WRAP makes out-of-bounds reads wrap around the input, i.e. tiling.
    layer.mode = trt.SampleMode.WRAP
    # Bug fix: pass source_ir so the layer name carries IR provenance,
    # consistent with the other converters in this file (e.g. cumsum).
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestTileConverter(DispatchTestCase):
    """Exercise the aten.tile converter across 1D/2D/3D inputs and
    repeat-dim sequences shorter, equal to, and longer than the input rank,
    including zero repeat counts (empty outputs)."""

    def _check_tile(self, shape, dims):
        # Shared driver: build a tiny module that tiles its input with the
        # captured `dims`, then compare TRT output against eager PyTorch.
        class TileModule(nn.Module):
            def forward(self, x):
                return torch.ops.aten.tile.default(x, dims)

        self.run_test(TileModule(), [torch.randn(shape)])

    @parameterized.expand(
        [
            ((3,), (1,)),
            ((3,), (0,)),
            ((3,), (2,)),
            ((2,), (2, 2)),
            ((2,), (0, 2)),
        ]
    )
    def test_tile_1D(self, shape, dims):
        self._check_tile(shape, dims)

    @parameterized.expand(
        [
            ((3, 1), (0,)),
            ((3, 1), (2,)),
            ((2, 3), (2, 2)),
            ((2, 3), (1, 0)),
            ((2, 3), (0, 2)),
            ((2, 3), (4, 2, 3)),
            ((2, 3), (0, 0, 3)),
            ((2, 3), (4, 2, 3, 1, 2)),
        ]
    )
    def test_tile_2D(self, shape, dims):
        self._check_tile(shape, dims)

    @parameterized.expand(
        [
            ((4, 2, 3), (2,)),
            ((4, 2, 3), (1, 2)),
            ((1, 2, 3), (2, 3)),
            ((1, 2, 3), (2, 3, 4)),
            ((1, 2, 3), (2, 3, 4, 5)),
        ]
    )
    def test_tile_3D(self, shape, dims):
        self._check_tile(shape, dims)
72+
73+
74+
# Allow running this test module directly (outside pytest) via the
# PyTorch internal test runner.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)