Commit 6d739ee

chore: dynamic shape support for any/sort/trunc ops
1 parent 622ca53 commit 6d739ee

6 files changed (+294, -10 lines)


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 14 additions & 7 deletions

@@ -19,6 +19,7 @@
     get_positive_dim,
     is_only_operator_on_placeholder,
 )
+from torch_tensorrt.dynamo.utils import TRT_TOPK_MAX_ELEMENT
 from torch_tensorrt.fx.types import TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -2648,6 +2649,10 @@ def topk_validator(node: Node) -> bool:
 
 
 def sort_validator(node: Node) -> bool:
+    # if meta data is not available(e.g. dynamic shape), validate k value during runtime.
+    if not node.args[0].meta:
+        return True
+
     shape = node.args[0].meta.get("tensor_meta").shape
     dim = node.args[1]
     dim = get_positive_dim(dim, len(shape))
@@ -2656,9 +2661,9 @@ def sort_validator(node: Node) -> bool:
 
 
 def topk_sort_validator(k: int) -> bool:
-    if k > 3840:
+    if k > TRT_TOPK_MAX_ELEMENT:
         _LOGGER.debug(
-            f"Currently only topk values up to 3840 are supported, got k={k}."
+            f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported, got k={k}."
         )
         return False
     return True
@@ -3103,7 +3108,9 @@ def aten_ops_topk(
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.sort.default, capability_validator=sort_validator
+    torch.ops.aten.sort.default,
+    capability_validator=sort_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3128,7 +3135,7 @@ def aten_ops_sort(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.trunc.default)
+@dynamo_tensorrt_converter(torch.ops.aten.trunc.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3204,9 +3211,9 @@ def aten_ops_remainder(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.any.default)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dim)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dims)
+@dynamo_tensorrt_converter(torch.ops.aten.any.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dim, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dims, supports_dynamic_shapes=True)
 def aten_ops_any(
     ctx: ConversionContext,
     target: Target,
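
For context, here is a rough sketch of what these registrations enable from the user side. The model, shapes, and compile call below are illustrative assumptions, not part of the commit; with supports_dynamic_shapes=True the aten.sort / aten.trunc / aten.any nodes can be converted even when inputs carry dynamic dimensions, instead of being partitioned back to PyTorch.

```python
# Hypothetical usage sketch (not from this commit): a model exercising
# aten.sort / aten.trunc / aten.any with a dynamic batch dimension.
import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def forward(self, x):
        values, _ = torch.sort(x, dim=-1, descending=True)  # aten.sort.default
        return torch.trunc(values), torch.any(x > 0, dim=1)  # aten.trunc / aten.any.dim


inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 8, 16),
        opt_shape=(4, 8, 16),
        max_shape=(8, 8, 16),
        dtype=torch.float32,
    )
]

trt_module = torch_tensorrt.compile(Model().eval().cuda(), ir="dynamo", inputs=inputs)
out = trt_module(torch.randn(4, 8, 16, device="cuda"))
```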

py/torch_tensorrt/dynamo/conversion/impl/topk.py

Lines changed: 36 additions & 2 deletions

@@ -10,9 +10,12 @@
     flatten_dims,
     get_axes_for_reduce_op,
     get_positive_dim,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import le
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, TRT_TOPK_MAX_ELEMENT
 
 
 def argmax_argmin(
@@ -155,6 +158,37 @@ def topk(
         k,
         get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
     )
+    if k == DYNAMIC_DIM:
+        output_shape = get_shape_with_dynamic_shape(
+            ctx, target, source_ir, name, input.shape, input
+        )
+        layer = ctx.net.add_slice(
+            output_shape,
+            start=[dim],
+            shape=[1],
+            stride=[1],
+        )
+        set_layer_name(layer, target, name)
+
+        # Get scalar tensor from 1d tensor
+        shuffle_layer = ctx.net.add_shuffle(layer.get_output(0))
+        shuffle_layer.reshape_dims = trt.Dims()
+        set_layer_name(shuffle_layer, target, name, source_ir)
+
+        cond = le(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_k_cond",
+            shuffle_layer.get_output(0),
+            TRT_TOPK_MAX_ELEMENT,
+        )
+        ctx.net.add_assertion(
+            cond,
+            message=f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported",
+        )
+
+        topk_layer.set_input(1, shuffle_layer.get_output(0))
     # TensorRT ITopKLayer does not have a sorted flag, it is always returning the sorted topk elements
     # so here no matter sorted is True or False the returned the topk Tensor object is always sorted
     set_layer_name(topk_layer, target, name, source_ir)
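
When k comes from a dynamic dimension (k == DYNAMIC_DIM), the converter above reads the actual size from the shape tensor at runtime, collapses it to a scalar, asserts it does not exceed TRT_TOPK_MAX_ELEMENT, and wires it into the TopK layer's K input. Below is a stand-alone sketch of that same pattern in the raw TensorRT Python API; it assumes TensorRT 10 (int64 shape tensors, explicit batch only) and the runtime-K input of ITopKLayer, and is illustrative rather than a copy of the commit's helper-based code.

```python
import numpy as np
import tensorrt as trt

TRT_TOPK_MAX_ELEMENT = 3840  # same ceiling the commit hoists into dynamo/utils.py

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(0)  # explicit batch is the only mode in TensorRT 10

# Input with a dynamic last dimension; k is taken from that dimension at runtime.
x = network.add_input("x", trt.float32, (4, -1))

# shape(x) -> slice out dim 1 -> reshape the 1-d slice to a 0-d scalar.
shape = network.add_shape(x).get_output(0)
k_1d = network.add_slice(shape, start=(1,), shape=(1,), stride=(1,)).get_output(0)
shuffle = network.add_shuffle(k_1d)
shuffle.reshape_dims = trt.Dims()
k_scalar = shuffle.get_output(0)

# cond = (k < TRT_TOPK_MAX_ELEMENT + 1), i.e. k <= TRT_TOPK_MAX_ELEMENT.
limit = network.add_constant(
    (), np.array(TRT_TOPK_MAX_ELEMENT + 1, dtype=np.int64)
).get_output(0)
cond = network.add_elementwise(
    k_scalar, limit, trt.ElementWiseOperation.LESS
).get_output(0)

# The engine aborts at runtime with this message if the condition is false.
network.add_assertion(
    cond, f"Currently only topk values up to {TRT_TOPK_MAX_ELEMENT} are supported"
)

# TopK with a placeholder k; the runtime value is wired in through input 1.
topk = network.add_topk(x, trt.TopKOperation.MAX, 1, axes=1 << 1)
topk.set_input(1, k_scalar)
network.mark_output(topk.get_output(0))
```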

py/torch_tensorrt/dynamo/utils.py

Lines changed: 2 additions & 1 deletion

@@ -6,14 +6,14 @@
 from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
-import tensorrt as trt
 from packaging import version
 
 from .types import TRTDataType
@@ -22,6 +22,7 @@
 
 COSINE_THRESHOLD = 0.99
 DYNAMIC_DIM = -1
+TRT_TOPK_MAX_ELEMENT = 3840
 
 
 class Frameworks(Enum):
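
For reference, 3840 matches the maximum K documented for TensorRT's ITopKLayer, and defining it once in dynamo/utils.py lets the build-time validator and the runtime assertion share the same limit. A hypothetical consumer (the real call sites are in the converter and impl files above) would simply import it:

```python
from torch_tensorrt.dynamo.utils import TRT_TOPK_MAX_ELEMENT


def k_is_supported(k: int) -> bool:
    # Hypothetical helper mirroring topk_sort_validator's check in this commit.
    return k <= TRT_TOPK_MAX_ELEMENT
```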

tests/py/dynamo/conversion/test_any.py

Lines changed: 148 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -184,5 +185,152 @@ def forward(self, x):
         )
 
 
+class TestAnyConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+            ),
+            (
+                "2d_dynamic_int32",
+                (2, 2),
+                (2, 2),
+                (3, 2),
+                torch.int32,
+            ),
+            (
+                "4d_dynamic_bool",
+                (1, 2, 1, 1),
+                (2, 2, 2, 2),
+                (2, 2, 4, 3),
+                torch.bool,
+            ),
+        ]
+    )
+    def test_any_dynamic(self, _, min_shape, opt_shape, max_shape, type):
+        class Any(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.default(x)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Any(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dim_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                2,
+                True,
+            ),
+            (
+                "4d_dynamic_dim_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                -2,
+                False,
+            ),
+            (
+                "3d_dynamic_dim_bool",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.bool,
+                0,
+                True,
+            ),
+        ]
+    )
+    def test_any_dynamic_dim(
+        self, _, min_shape, opt_shape, max_shape, type, dim, keep_dims
+    ):
+        class AnyDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dim(x, dim, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDim(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dims_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                [1, 2],
+                True,
+            ),
+            (
+                "4d_dynamic_dims_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                [2, -1],
+                False,
+            ),
+            (
+                "3d_dynamic_dims_bool",
+                (1, 4, 1),
+                (2, 4, 2),
+                (4, 4, 3),
+                torch.bool,
+                [0, 1, 2],
+                False,
+            ),
+        ]
+    )
+    def test_any_dynamic_dims(
+        self, _, min_shape, opt_shape, max_shape, type, dims, keep_dims
+    ):
+        class AnyDims(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dims(x, dims, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDims(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_sort_aten.py

Lines changed: 50 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -32,5 +33,54 @@ def forward(self, x):
         )
 
 
+class TestSortConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_descending",
+                (2, 2, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                0,
+                True,
+            ),
+            (
+                "4d_dynamic_ascending",
+                (2, 2, 1, 1),
+                (2, 2, 1, 2),
+                (3, 3, 2, 4),
+                3,
+                False,
+            ),
+            (
+                "4d_dynamic_descending_neg_dim",
+                (2, 2, 1, 1),
+                (2, 2, 1, 2),
+                (3, 3, 2, 4),
+                -3,
+                True,
+            ),
+        ]
+    )
+    def test_sort_dynamic(self, _, min_shape, opt_shape, max_shape, dim, descending):
+        class Sort(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sort.default(x, dim, descending)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Sort(),
+            input_specs,
+            output_dtypes=[torch.float, torch.int64],
+        )
+
+
 if __name__ == "__main__":
     run_tests()
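
Both new test classes follow the existing DispatchTestCase pattern: each case builds Input specs with min/opt/max shapes and hands them to run_test_with_dynamic_shape, which compiles the module under a dynamic-shape profile and checks the TensorRT results against eager PyTorch. Running a file directly (e.g. `python tests/py/dynamo/conversion/test_sort_aten.py`) invokes run_tests() from torch.testing, as in the rest of the suite.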
