Skip to content

Commit fc9960a

Browse files
committed
feat: support aten.flip dynamo converter
1 parent df03def commit fc9960a

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2541,3 +2541,26 @@ def aten_ops_sort(
25412541
dim=args_bounds_check(args, 1, -1),
25422542
descending=args_bounds_check(args, 2, False),
25432543
)
2544+
2545+
2546+
@dynamo_tensorrt_converter(torch.ops.aten.flip.default)
2547+
@enforce_tensor_types(
2548+
{
2549+
0: (TRTTensor,),
2550+
}
2551+
)
2552+
def aten_ops_flip(
2553+
ctx: ConversionContext,
2554+
target: Target,
2555+
args: Tuple[Argument, ...],
2556+
kwargs: Dict[str, Argument],
2557+
name: str,
2558+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2559+
return impl.slice.flip(
2560+
ctx,
2561+
target,
2562+
SourceIR.ATEN,
2563+
name,
2564+
args[0],
2565+
args[1],
2566+
)

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,36 @@ def tile(
225225
layer.mode = trt.SampleMode.WRAP
226226
set_layer_name(layer, target, name)
227227
return layer.get_output(0)
228+
229+
230+
def flip(
231+
ctx: ConversionContext,
232+
target: Target,
233+
source_ir: Optional[SourceIR],
234+
name: str,
235+
input: TRTTensor,
236+
dims: Sequence[int],
237+
) -> TRTTensor:
238+
start_slice = []
239+
output_shape = list(input.shape)
240+
stride_slice = []
241+
242+
shape = input.shape
243+
rank = len(shape)
244+
245+
for i in range(rank):
246+
if i in dims:
247+
start_slice.append(shape[i] - 1)
248+
stride_slice.append(-1)
249+
else:
250+
start_slice.append(0)
251+
stride_slice.append(1)
252+
253+
layer = ctx.net.add_slice(
254+
input,
255+
start=start_slice,
256+
shape=output_shape,
257+
stride=stride_slice,
258+
)
259+
set_layer_name(layer, target, name, source_ir)
260+
return layer.get_output(0)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestFlipConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((3,), [0]),
13+
((3, 3), [0, 1]),
14+
((2, 3, 4), [0]),
15+
((3, 3, 3), (0, 1)),
16+
((2, 3, 4), [0, 1, 2]),
17+
((3, 3, 3, 3), [0]),
18+
((2, 3, 4, 5), [0, 1, 2, 3]),
19+
]
20+
)
21+
def test_flip(self, shape, dims):
22+
class Flip(nn.Module):
23+
def forward(self, x):
24+
return torch.ops.aten.flip.default(x, dims)
25+
26+
inputs = [torch.randn(shape)]
27+
self.run_test(Flip(), inputs)
28+
29+
30+
if __name__ == "__main__":
31+
run_tests()

0 commit comments

Comments
 (0)