
Commit e1491b0

chore: Cannot support dynamic k in topk
1 parent: 6d739ee

4 files changed: +22 -50 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 8 additions & 8 deletions
@@ -19,7 +19,6 @@
     get_positive_dim,
     is_only_operator_on_placeholder,
 )
-from torch_tensorrt.dynamo.utils import TRT_TOPK_MAX_ELEMENT
 from torch_tensorrt.fx.types import TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -2649,21 +2648,22 @@ def topk_validator(node: Node) -> bool:
 
 
 def sort_validator(node: Node) -> bool:
-    # if meta data is not available(e.g. dynamic shape), validate k value during runtime.
-    if not node.args[0].meta:
-        return True
-
-    shape = node.args[0].meta.get("tensor_meta").shape
+    meta_data = node.args[0].meta.get("tensor_meta")
+    if meta_data is None:
+        return False
+    shape = meta_data.shape
     dim = node.args[1]
     dim = get_positive_dim(dim, len(shape))
     k = shape[dim]
+    if not isinstance(k, int):
+        return False
     return topk_sort_validator(k)
 
 
 def topk_sort_validator(k: int) -> bool:
-    if k > TRT_TOPK_MAX_ELEMENT:
+    if k > 3840:
         _LOGGER.debug(
-            f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported, got k={k}."
+            f"Currently only topk values up to 3840 are supported, got k={k}."
         )
         return False
     return True
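
Net effect of this hunk: the validator's fallback direction is inverted. Missing tensor metadata previously returned True and deferred the k check to runtime; it now returns False so the node stays in PyTorch, and a symbolic (dynamic) dimension size fails the new isinstance(k, int) check for the same reason, since dynamic dims surface as torch.SymInt rather than plain int. A standalone sketch of the decision (hypothetical helper name, not from the commit; 3840 is TensorRT's documented ITopKLayer K limit):

# Hypothetical, self-contained restatement of the validator logic above.
TRT_TOPK_MAX = 3840  # documented maximum K for TensorRT's ITopKLayer

def can_convert_topk(k) -> bool:
    # Under symbolic tracing a dynamic dim is a torch.SymInt, which is not
    # an int subclass, so it is rejected at compile time.
    if not isinstance(k, int):
        return False
    return k <= TRT_TOPK_MAX

print(can_convert_topk(128))   # True  -> convertible to a TensorRT topk layer
print(can_convert_topk(4096))  # False -> exceeds 3840, falls back to PyTorch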

py/torch_tensorrt/dynamo/conversion/impl/topk.py

Lines changed: 5 additions & 33 deletions
@@ -12,10 +12,8 @@
     get_positive_dim,
     set_layer_name,
 )
-from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import le
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.types import TRTTensor
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, TRT_TOPK_MAX_ELEMENT
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 
 
 def argmax_argmin(
@@ -158,40 +156,14 @@ def topk(
         k,
         get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
     )
-    if k == DYNAMIC_DIM:
-        output_shape = get_shape_with_dynamic_shape(
-            ctx, target, source_ir, name, input.shape, input
-        )
-        layer = ctx.net.add_slice(
-            output_shape,
-            start=[dim],
-            shape=[1],
-            stride=[1],
-        )
-        set_layer_name(layer, target, name)
 
-        # Get scalar tensor from 1d tensor
-        shuffle_layer = ctx.net.add_shuffle(layer.get_output(0))
-        shuffle_layer.reshape_dims = trt.Dims()
-        set_layer_name(shuffle_layer, target, name, source_ir)
+    # The topk layer supports a dynamic k value, but we cannot determine at
+    # compile time whether a dynamic k is within the supported range.
+    assert k != DYNAMIC_DIM, "k value cannot be dynamic!"
 
-        cond = le(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_k_cond",
-            shuffle_layer.get_output(0),
-            TRT_TOPK_MAX_ELEMENT,
-        )
-        ctx.net.add_assertion(
-            cond,
-            message=f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported",
-        )
-
-        topk_layer.set_input(1, shuffle_layer.get_output(0))
     # TensorRT ITopKLayer does not have a sorted flag; it always returns the sorted
     # topk elements, so whether sorted is True or False the returned topk tensor is sorted.
-    set_layer_name(topk_layer, target, name, source_ir)
+    set_layer_name(topk_layer, target, f"{name}_topk", source_ir)
 
     if return_indices:
         return topk_layer.get_output(0), topk_layer.get_output(1)
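
With the dynamic-k branch gone (the slice/shuffle/assertion machinery that fed k as a runtime tensor input), the converter always bakes a static k into the engine. A minimal sketch of what the emitted layer boils down to, using the plain TensorRT Python API; the toy network and names are illustrative and not from this commit, and it assumes an EXPLICIT_BATCH network as used before TensorRT 10:

import tensorrt as trt

# Build a toy network containing a single static-k topk layer.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (2, 8))
# k is a Python int fixed at build time; axes is a bitmask selecting dim 1.
topk = network.add_topk(x, trt.TopKOperation.MAX, k=3, axes=1 << 1)
network.mark_output(topk.get_output(0))  # values, sorted
network.mark_output(topk.get_output(1))  # int32 indices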

py/torch_tensorrt/dynamo/utils.py

Lines changed: 1 addition & 2 deletions
@@ -6,14 +6,14 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import numpy as np
-import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
+import tensorrt as trt
 from packaging import version
 
 from .types import TRTDataType
@@ -22,7 +22,6 @@
 
 COSINE_THRESHOLD = 0.99
 DYNAMIC_DIM = -1
-TRT_TOPK_MAX_ELEMENT = 3840
 
 
 class Frameworks(Enum):

tests/py/dynamo/conversion/test_sort_aten.py

Lines changed: 8 additions & 7 deletions
@@ -38,24 +38,24 @@ class TestSortConverterDynamic(DispatchTestCase):
         [
             (
                 "3d_dynamic_descending",
-                (2, 2, 1),
-                (2, 2, 1),
+                (2, 1, 4),
                 (3, 2, 4),
-                0,
+                (3, 3, 4),
+                2,
                 True,
             ),
             (
                 "4d_dynamic_ascending",
-                (2, 2, 1, 1),
-                (2, 2, 1, 2),
+                (2, 2, 1, 4),
+                (2, 2, 2, 4),
                 (3, 3, 2, 4),
                 3,
                 False,
             ),
             (
                 "4d_dynamic_descending_neg_dim",
-                (2, 2, 1, 1),
-                (2, 2, 1, 2),
+                (1, 3, 1, 1),
+                (2, 3, 2, 2),
                 (3, 3, 2, 4),
                 -3,
                 True,
@@ -79,6 +79,7 @@ def forward(self, x):
             Sort(),
             input_specs,
             output_dtypes=[torch.float, torch.int64],
+            use_dynamo_tracer=True,
         )
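
For reference, the three shape tuples in each parameterized case are the dynamic range of the test input. A hedged sketch (not from this file) of what the "3d_dynamic_descending" case roughly corresponds to as a torch_tensorrt input spec:

import torch
import torch_tensorrt

# Illustrative mapping of the (min, opt, max) test shapes to a dynamic input spec.
input_specs = [
    torch_tensorrt.Input(
        min_shape=(2, 1, 4),
        opt_shape=(3, 2, 4),
        max_shape=(3, 3, 4),
        dtype=torch.float32,
    )
]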
