Skip to content

Commit 2f569c3

Browse files
authored
feat: support aten.sort dynamo converter (#2514)
1 parent f0e6d2d commit 2f569c3

File tree

3 files changed

+92
-1
lines changed

3 files changed

+92
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2505,3 +2505,27 @@ def upsample_bilinear2d(
25052505
resize_mode="bilinear",
25062506
align_corners=args_bounds_check(args, 2),
25072507
)
2508+
2509+
2510+
@dynamo_tensorrt_converter(torch.ops.aten.sort.default)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_sort(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.sort.default`` to TensorRT via the TopK-based sort impl.

    args[0] is the input tensor (enforced to be a TRTTensor); args[1] is the
    optional sort dimension (defaults to -1, the last axis) and args[2] the
    optional descending flag (defaults to False), matching the aten schema.
    """
    # Pull the optional schema arguments out with their aten defaults before
    # delegating, so the call site reads unambiguously.
    sort_dim = args_bounds_check(args, 1, -1)
    sort_descending = args_bounds_check(args, 2, False)
    return impl.topk.sort(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        dim=sort_dim,
        descending=sort_descending,
    )

py/torch_tensorrt/dynamo/conversion/impl/topk.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Tuple, Union
22

33
import tensorrt as trt
44
from torch.fx.node import Target
@@ -101,3 +101,36 @@ def argmin(
101101
return argmax_argmin(
102102
ctx, target, source_ir, name, input, trt.TopKOperation.MIN, dim, keep_dim
103103
)
104+
105+
106+
def sort(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    descending: bool,
    return_indices: bool = True,
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
    """Sort ``input`` along ``dim`` using a TensorRT TopK layer.

    A full-axis TopK (k equal to the dimension's extent) is equivalent to a
    sort: MAX ordering yields descending output, MIN ordering ascending.

    Args:
        ctx: Conversion context holding the network being built.
        target: The fx node target being converted.
        source_ir: Originating IR, used for layer naming.
        name: Base name for the created layer.
        input: Tensor to sort.
        dim: Axis to sort along; may be negative.
        descending: Sort order — True selects MAX ordering, False MIN.
        return_indices: When True (default), also return the index tensor,
            mirroring ``torch.sort``'s (values, indices) pair.

    Returns:
        The sorted values tensor, or a (values, indices) tuple when
        ``return_indices`` is True.
    """
    # The two orderings differ only in the TopK operation; select it once
    # instead of duplicating the layer construction.
    topk_op = trt.TopKOperation.MAX if descending else trt.TopKOperation.MIN
    # Reduce-axis bitmask requires a non-negative dimension index.
    axis_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
    # k = full extent of the axis turns TopK into a complete sort.
    # NOTE(review): assumes input.shape[dim] is static — TODO confirm dynamic
    # shapes are rejected upstream.
    topk_layer = ctx.net.add_topk(input, topk_op, input.shape[dim], axis_mask)
    set_layer_name(topk_layer, target, name, source_ir)

    if not return_indices:
        return topk_layer.get_output(0)
    return topk_layer.get_output(0), topk_layer.get_output(1)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestSortConverter(DispatchTestCase):
    """Check the aten.sort.default converter against PyTorch eager output."""

    @parameterized.expand(
        [
            # (input_shape, dim, descending) — covers positive and negative
            # dims, both sort orders, and ranks 3 through 5.
            ((3, 2, 4), 0, True),
            ((2, 3, 4, 5), 1, True),
            ((2, 3, 4, 5), 2, False),
            ((6, 7, 5, 4, 5), 4, False),
            ((1, 5, 2, 1), -1, True),
            ((1, 2, 5, 3), -2, False),
            ((6, 2, 1, 3), -4, True),
        ]
    )
    def test_sort(self, input_shape, dim, descending):
        class SortModule(nn.Module):
            def forward(self, x):
                return torch.ops.aten.sort.default(x, dim, descending)

        self.run_test(SortModule(), [torch.randn(*input_shape)])
31+
32+
33+
# Allow running this test file directly (outside a pytest/CI collection).
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)