Skip to content

Commit 58c8f2f

Browse files
committed
feat: support tile dynamo converter
1 parent acc248b commit 58c8f2f

File tree

3 files changed

+137
-1
lines changed

3 files changed

+137
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,29 @@ def aten_ops_chunk(
691691
)
692692

693693

694+
@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
695+
@enforce_tensor_types(
696+
{
697+
0: (TRTTensor,),
698+
}
699+
) # type: ignore[misc]
700+
def aten_ops_tile(
701+
ctx: ConversionContext,
702+
target: Target,
703+
args: Tuple[Argument, ...],
704+
kwargs: Dict[str, Argument],
705+
name: str,
706+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
707+
return impl.slice.tile(
708+
ctx,
709+
target,
710+
SourceIR.ATEN,
711+
name,
712+
args[0],
713+
args[1],
714+
)
715+
716+
694717
@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
695718
@enforce_tensor_types(
696719
{

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import math
2-
from typing import Optional
2+
from typing import Optional, Sequence
33

4+
import tensorrt as trt
45
from torch.fx.node import Target
56
from torch_tensorrt.dynamo._SourceIR import SourceIR
7+
from torch_tensorrt.dynamo.conversion import impl
68
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
79
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
810
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
@@ -157,3 +159,39 @@ def chunk(
157159
cnt += 1
158160

159161
return result
162+
163+
164+
def tile(
165+
ctx: ConversionContext,
166+
target: Target,
167+
source_ir: Optional[SourceIR],
168+
name: str,
169+
input: TRTTensor,
170+
dims: Sequence[int],
171+
) -> TRTTensor:
172+
diff = len(dims) - len(input.shape)
173+
if diff > 0:
174+
# prepend 1 to input.shape
175+
new_shape = (1,) * diff + tuple(input.shape)
176+
input = impl.shuffle.reshape(
177+
ctx, target, source_ir, f"{name}_prepend_input_shape", input, new_shape
178+
)
179+
elif diff < 0:
180+
# prepend 1 to dims
181+
dims = (1,) * -diff + tuple(dims)
182+
183+
if all(isinstance(d, int) for d in dims):
184+
shapes = [i * j for i, j in zip(input.shape, dims)]
185+
else:
186+
shapes = []
187+
for i, (s, d) in enumerate(zip(input.shape, dims)):
188+
shapes.append(
189+
impl.elementwise.mul(ctx, target, source_ir, f"{name}_mul_{i}", s, d)
190+
)
191+
192+
starts = [0] * len(dims)
193+
strides = [1] * len(dims)
194+
layer = ctx.net.add_slice(input, tuple(starts), tuple(shapes), tuple(strides))
195+
layer.mode = trt.SampleMode.WRAP
196+
set_layer_name(layer, target, name)
197+
return layer.get_output(0)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestTileConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((3,), (1,)),
13+
((3,), (0,)),
14+
((3,), (2,)),
15+
((2,), (2, 2)),
16+
((2,), (0, 2)),
17+
]
18+
)
19+
def test_tile_1D(self, shape, dims):
20+
class Tile(nn.Module):
21+
def forward(self, x):
22+
return torch.ops.aten.tile.default(x, dims)
23+
24+
inputs = [torch.randn(shape)]
25+
self.run_test(
26+
Tile(),
27+
inputs,
28+
)
29+
30+
@parameterized.expand(
31+
[
32+
((3, 1), (0,)),
33+
((3, 1), (2,)),
34+
((2, 3), (2, 2)),
35+
((2, 3), (1, 0)),
36+
((2, 3), (0, 2)),
37+
((2, 3), (4, 2, 3)),
38+
((2, 3), (0, 0, 3)),
39+
((2, 3), (4, 2, 3, 1, 2)),
40+
]
41+
)
42+
def test_tile_2D(self, shape, dims):
43+
class Tile(nn.Module):
44+
def forward(self, x):
45+
return torch.ops.aten.tile.default(x, dims)
46+
47+
inputs = [torch.randn(shape)]
48+
self.run_test(
49+
Tile(),
50+
inputs,
51+
)
52+
53+
@parameterized.expand(
54+
[
55+
((4, 2, 3), (2,)),
56+
((4, 2, 3), (1, 2)),
57+
((1, 2, 3), (2, 3)),
58+
((1, 2, 3), (2, 3, 4)),
59+
((1, 2, 3), (2, 3, 4, 5)),
60+
]
61+
)
62+
def test_tile_3D(self, shape, dims):
63+
class Tile(nn.Module):
64+
def forward(self, x):
65+
return torch.ops.aten.tile.default(x, dims)
66+
67+
inputs = [torch.randn(shape)]
68+
self.run_test(
69+
Tile(),
70+
inputs,
71+
)
72+
73+
74+
if __name__ == "__main__":
75+
run_tests()

0 commit comments

Comments
 (0)