Skip to content

Commit 9fd1157

Browse files
committed
feat: support tile dynamo converter
1 parent 6ab74fe commit 9fd1157

File tree

3 files changed

+127
-2
lines changed

3 files changed

+127
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,29 @@ def aten_ops_cumsum(
714714
)
715715

716716

717+
@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
718+
@enforce_tensor_types(
719+
{
720+
0: (TRTTensor,),
721+
}
722+
) # type: ignore[misc]
723+
def aten_ops_tile(
724+
ctx: ConversionContext,
725+
target: Target,
726+
args: Tuple[Argument, ...],
727+
kwargs: Dict[str, Argument],
728+
name: str,
729+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
730+
return impl.slice.tile(
731+
ctx,
732+
target,
733+
SourceIR.ATEN,
734+
name,
735+
args[0],
736+
args[1],
737+
)
738+
739+
717740
@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
718741
@enforce_tensor_types(
719742
{

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import math
2-
from typing import Optional
2+
from typing import Optional, Sequence
33

4-
import numpy as np
54
import tensorrt as trt
65
from torch.fx.node import Target
76
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -203,3 +202,31 @@ def cumsum(
203202
set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
204203
loop_output.set_input(1, trip_limit)
205204
return loop_output.get_output(0)
205+
206+
207+
def tile(
208+
ctx: ConversionContext,
209+
target: Target,
210+
source_ir: Optional[SourceIR],
211+
name: str,
212+
input: TRTTensor,
213+
dims: Sequence[int],
214+
) -> TRTTensor:
215+
diff = len(dims) - len(input.shape)
216+
if diff > 0:
217+
# prepend 1 to input.shape
218+
new_shape = (1,) * diff + tuple(input.shape)
219+
input = impl.shuffle.reshape(
220+
ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape
221+
)
222+
elif diff < 0:
223+
# prepend 1 to dims
224+
dims = (1,) * -diff + tuple(dims)
225+
226+
shapes = [i * j for i, j in zip(input.shape, dims)]
227+
starts = [0] * len(dims)
228+
strides = [1] * len(dims)
229+
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
230+
layer.mode = trt.SampleMode.WRAP
231+
set_layer_name(layer, target, name)
232+
return layer.get_output(0)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestTileConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((3,), (1,)),
13+
((3,), (0,)),
14+
((3,), (2,)),
15+
((2,), (2, 2)),
16+
((2,), (0, 2)),
17+
]
18+
)
19+
def test_tile_1D(self, shape, dims):
20+
class Tile(nn.Module):
21+
def forward(self, x):
22+
return torch.ops.aten.tile.default(x, dims)
23+
24+
inputs = [torch.randn(shape)]
25+
self.run_test(
26+
Tile(),
27+
inputs,
28+
)
29+
30+
@parameterized.expand(
31+
[
32+
((3, 1), (0,)),
33+
((3, 1), (2,)),
34+
((2, 3), (2, 2)),
35+
((2, 3), (1, 0)),
36+
((2, 3), (0, 2)),
37+
((2, 3), (4, 2, 3)),
38+
((2, 3), (0, 0, 3)),
39+
((2, 3), (4, 2, 3, 1, 2)),
40+
]
41+
)
42+
def test_tile_2D(self, shape, dims):
43+
class Tile(nn.Module):
44+
def forward(self, x):
45+
return torch.ops.aten.tile.default(x, dims)
46+
47+
inputs = [torch.randn(shape)]
48+
self.run_test(
49+
Tile(),
50+
inputs,
51+
)
52+
53+
@parameterized.expand(
54+
[
55+
((4, 2, 3), (2,)),
56+
((4, 2, 3), (1, 2)),
57+
((1, 2, 3), (2, 3)),
58+
((1, 2, 3), (2, 3, 4)),
59+
((1, 2, 3), (2, 3, 4, 5)),
60+
]
61+
)
62+
def test_tile_3D(self, shape, dims):
63+
class Tile(nn.Module):
64+
def forward(self, x):
65+
return torch.ops.aten.tile.default(x, dims)
66+
67+
inputs = [torch.randn(shape)]
68+
self.run_test(
69+
Tile(),
70+
inputs,
71+
)
72+
73+
74+
if __name__ == "__main__":
75+
run_tests()

0 commit comments

Comments
 (0)