Skip to content

Commit fd19353

Browse files
authored
feat: support aten.flip dynamo converter (#2540)
1 parent b8403b8 commit fd19353

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2626,3 +2626,26 @@ def aten_ops_pdist(
26262626
args[0],
26272627
args_bounds_check(args, 1, 2),
26282628
)
2629+
2630+
2631+
@dynamo_tensorrt_converter(torch.ops.aten.flip.default)
2632+
@enforce_tensor_types(
2633+
{
2634+
0: (TRTTensor,),
2635+
}
2636+
)
2637+
def aten_ops_flip(
2638+
ctx: ConversionContext,
2639+
target: Target,
2640+
args: Tuple[Argument, ...],
2641+
kwargs: Dict[str, Argument],
2642+
name: str,
2643+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2644+
return impl.slice.flip(
2645+
ctx,
2646+
target,
2647+
SourceIR.ATEN,
2648+
name,
2649+
args[0],
2650+
args[1],
2651+
)

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,37 @@ def tile(
225225
layer.mode = trt.SampleMode.WRAP
226226
set_layer_name(layer, target, name)
227227
return layer.get_output(0)
228+
229+
230+
def flip(
231+
ctx: ConversionContext,
232+
target: Target,
233+
source_ir: Optional[SourceIR],
234+
name: str,
235+
input: TRTTensor,
236+
dims: Sequence[int],
237+
) -> TRTTensor:
238+
start_slice = []
239+
output_shape = list(input.shape)
240+
stride_slice = []
241+
242+
shape = input.shape
243+
rank = len(shape)
244+
dims = get_positive_dim(dims, rank)
245+
246+
for i in range(rank):
247+
if i in dims:
248+
start_slice.append(shape[i] - 1)
249+
stride_slice.append(-1)
250+
else:
251+
start_slice.append(0)
252+
stride_slice.append(1)
253+
254+
layer = ctx.net.add_slice(
255+
input,
256+
start=start_slice,
257+
shape=output_shape,
258+
stride=stride_slice,
259+
)
260+
set_layer_name(layer, target, name, source_ir)
261+
return layer.get_output(0)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestFlipConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((3,), [0]),
13+
((3,), [-1]),
14+
((3,), []),
15+
((3, 3), [0, 1]),
16+
((3, 3), [-2, 1]),
17+
((2, 3, 4), [0]),
18+
((3, 3, 3), (0, 1)),
19+
((2, 3, 4), [0, 1, 2]),
20+
((2, 3, 4), [-3, -2, -1]),
21+
((3, 3, 3, 3), [0]),
22+
((2, 3, 4, 5), [0, 1, 2, 3]),
23+
((2, 3, 4, 5), [-4, 1, -2, 3]),
24+
((2, 3, 4, 5), []),
25+
]
26+
)
27+
def test_flip(self, shape, dims):
28+
class Flip(nn.Module):
29+
def forward(self, x):
30+
return torch.ops.aten.flip.default(x, dims)
31+
32+
inputs = [torch.randn(shape)]
33+
self.run_test(Flip(), inputs)
34+
35+
36+
if __name__ == "__main__":
37+
run_tests()

0 commit comments

Comments
 (0)