Skip to content

Lluo/aten topk #2840

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docsrc/dynamo/torch_compile.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Custom Setting Usage
...
optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False,
options={"truncate_long_and_double": True,
"precision": torch.half,
"enabled_precisions": [torch.half],
"debug": True,
"min_block_size": 2,
"torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
Expand Down
26 changes: 26 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2667,6 +2667,32 @@ def upsample_bilinear2d(
)


@dynamo_tensorrt_converter(torch.ops.aten.topk.default)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_topk(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.topk.topk(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
k=args[1],
dim=args_bounds_check(args, 2, -1),
largest=args_bounds_check(args, 3, True),
sorted=args_bounds_check(args, 4, True),
)


@dynamo_tensorrt_converter(torch.ops.aten.sort.default)
@enforce_tensor_types(
{
Expand Down
35 changes: 35 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/topk.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi can you consolidate the topk() and the sort() function, something like the sort() function calls the topk function with the k arg as the dim?

Also in the below the sorted argument should also be handled to return the elements in sorted order.

Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,38 @@ def sort(
return topk_layer.get_output(0), topk_layer.get_output(1)
else:
return topk_layer.get_output(0)


def topk(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
k: int,
dim: int,
largest: bool,
sorted: bool,
return_indices: bool = True,
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
if largest:
topk_layer = ctx.net.add_topk(
input,
trt.TopKOperation.MAX,
k,
get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
)
else:
topk_layer = ctx.net.add_topk(
input,
trt.TopKOperation.MIN,
k,
get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
)

set_layer_name(topk_layer, target, name, source_ir)

if return_indices:
return topk_layer.get_output(0), topk_layer.get_output(1)
else:
return topk_layer.get_output(0)
31 changes: 31 additions & 0 deletions tests/py/dynamo/conversion/test_topk_aten.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more test cases with 2D and 3D cases? Also cases on how it would handle duplicate values.
Also could you once print the decomposed graph to see if aten.topk.value is invoked or not.

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestSortConverter(DispatchTestCase):
@parameterized.expand(
[
((3, 2, 4), 1, 0, True, True),
((3850, 2), 3840, 0, False, True),
# default dim:-1 largest:True, sorted:True
((3, 5, 12), 3),
]
)
def test_topk(self, input_shape, k, dim=-1, largest=True, sorted=True):
class Topk(nn.Module):
def forward(self, x):
return torch.ops.aten.topk.default(x, k, dim, largest, sorted)

inputs = [torch.randn(*input_shape)]
self.run_test(
Topk(),
inputs,
)


if __name__ == "__main__":
run_tests()
Loading