Commit 3390e24

feat: support aten.roll dynamo converter (#2569)
1 parent cf3a688 commit 3390e24

3 files changed (+132, -2 lines)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 24 additions & 0 deletions
@@ -2706,3 +2706,27 @@ def aten_ops_scalar_tensor(
     return impl.unary.scalar_tensor(
         ctx, target, SourceIR.ATEN, name, args[0], dtype=kwargs.get("dtype")
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.roll.default)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)
+def aten_ops_roll(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.permutation.roll(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args_bounds_check(args, 2, []),
+    )
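For reference, the eager-mode semantics this converter has to reproduce are those of torch.roll: when dims is given, each listed dimension is rotated by the corresponding shift; when dims is omitted (the args_bounds_check(args, 2, []) default above), the tensor is flattened, rolled by a single shift, and reshaped back. A minimal eager-mode sketch for illustration, not part of the commit:

import torch

x = torch.arange(8).reshape(4, 2)

# shifts and dims given: each listed dim is rotated by its shift
y1 = torch.roll(x, shifts=(2, 1), dims=(0, 1))

# dims omitted: flatten, roll by a single shift, reshape back
y2 = torch.roll(x, shifts=3)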

py/torch_tensorrt/dynamo/conversion/impl/permutation.py

Lines changed: 65 additions & 2 deletions
@@ -1,9 +1,14 @@
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union
 
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    flatten_dims,
+    get_positive_dim,
+)
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 

@@ -27,3 +32,61 @@ def permute(
     layer.second_transpose = tuple(permutation)
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def roll(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    shifts: Union[int, Sequence[int]],
+    dims: Union[int, Sequence[int]],
+) -> TRTTensor:
+    shape = input.shape
+    if isinstance(shifts, int):
+        shifts = [shifts]
+    if isinstance(dims, int):
+        dims = [dims]
+
+    if dims != []:
+        rank = len(shape)
+        start = [0] * rank
+        stride = [1] * rank
+        for i in range(len(dims)):
+            d = dims[i]
+            s = shifts[i]
+            start[d] += get_positive_dim(
+                -s, shape[d]
+            )  # in case dims contains the same dim more than once
+
+        layer = ctx.net.add_slice(
+            input,
+            start=start,
+            shape=shape,
+            stride=stride,
+        )
+        layer.mode = trt.SliceMode.WRAP
+        set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
+        return layer.get_output(0)
+
+    else:
+        flatten_shape = flatten_dims(input, 0, -1)
+        output = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
+        )
+        start = [get_positive_dim(-shifts[0], output.shape[0])]
+        stride = [1]
+        layer = ctx.net.add_slice(
+            output,
+            start=start,
+            shape=flatten_shape,
+            stride=stride,
+        )
+        layer.mode = trt.SliceMode.WRAP
+        set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
+        output = layer.get_output(0)
+        output = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_back", output, shape
+        )
+        return output
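Both branches rely on TensorRT's wrap-mode slice: by offsetting the slice start along a dimension by the positive equivalent of -shift, a full-size slice with stride 1 reads each element modulo the dimension size, which is exactly a roll. A plain-Python sketch of that equivalence (hypothetical helper name, no TensorRT involved), not part of the commit:

def roll_1d_via_wrap_slice(values, shift):
    size = len(values)
    start = (-shift) % size  # positive start, analogous to get_positive_dim(-s, shape[d])
    # a wrap-mode slice of length `size` starting at `start` reads
    # out[i] = values[(start + i) % size], which is a roll by `shift`
    return [values[(start + i) % size] for i in range(size)]

assert roll_1d_via_wrap_slice([0, 1, 2, 3], 1) == [3, 0, 1, 2]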
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestRollConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((4,), (2,), 0),
+            ((4,), [2], [0]),
+            ((4,), [3], [0]),
+            ((4,), [-3, 2], [0, 0]),
+            ((4,), [-2], []),
+            ((4, 2), [2, 1], [0, 1]),
+            ((3, 3), [2, 1], [1, 1]),
+            ((4, 2), [2, -1], [-2, -1]),
+            ((4, 2), [4], []),
+            ((3, 4, 2), [1, 0, 2], [2, 0, -2]),
+            ((3, 4, 2), [1, -0, 2], [1, 1, 1]),
+            (
+                (3, 4, 2),
+                [
+                    5,
+                ],
+                [],
+            ),
+        ]
+    )
+    def test_roll(self, shape, shifts, dims):
+        class Roll(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.roll.default(x, shifts, dims)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(Roll(), inputs)
+
+
+if __name__ == "__main__":
+    run_tests()
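Beyond the harness tests, a hedged end-to-end sketch of how the new converter could be exercised through the dynamo frontend once this commit is in place; the module and compile arguments below are illustrative, and exact options may differ by torch_tensorrt version:

import torch
import torch_tensorrt


class RollModule(torch.nn.Module):
    def forward(self, x):
        return torch.roll(x, shifts=(2, 1), dims=(0, 1))


model = RollModule().eval().cuda()
inputs = [torch.randn(4, 2).cuda()]

# compile through the dynamo path so aten.roll hits the new converter
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))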
