
Commit 23b4f1e

chore: dynamic shape support for any/sort/trunc ops (#3026)
1 parent 784fa57 commit 23b4f1e

5 files changed: +265 −9 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 13 additions & 6 deletions
@@ -2668,10 +2668,15 @@ def topk_validator(node: Node) -> bool:
 
 
 def sort_validator(node: Node) -> bool:
-    shape = node.args[0].meta.get("tensor_meta").shape
+    meta_data = node.args[0].meta.get("tensor_meta")
+    if meta_data is None:
+        return False
+    shape = meta_data.shape
     dim = node.args[1]
     dim = get_positive_dim(dim, len(shape))
     k = shape[dim]
+    if not isinstance(k, int):
+        return False
     return topk_sort_validator(k)
 
 
@@ -3436,7 +3441,9 @@ def aten_ops_topk(
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.sort.default, capability_validator=sort_validator
+    torch.ops.aten.sort.default,
+    capability_validator=sort_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3461,7 +3468,7 @@ def aten_ops_sort(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.trunc.default)
+@dynamo_tensorrt_converter(torch.ops.aten.trunc.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3537,9 +3544,9 @@ def aten_ops_remainder(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.any.default)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dim)
-@dynamo_tensorrt_converter(torch.ops.aten.any.dims)
+@dynamo_tensorrt_converter(torch.ops.aten.any.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dim, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.any.dims, supports_dynamic_shapes=True)
 def aten_ops_any(
     ctx: ConversionContext,
     target: Target,
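
Context for the new checks in sort_validator: when the graph is traced with dynamic shapes, a dynamic extent in tensor_meta.shape is recorded as a torch.SymInt rather than a plain Python int, and tensor_meta itself can be absent on some nodes. A minimal sketch of that idea (illustrative only; the helper name is made up and not part of this commit):

# Hedged sketch: a validator-style check that the dimension an op sorts or
# reduces over has a static (plain-int) extent in the node's tensor metadata.
from torch.fx.node import Node


def has_static_extent(node: Node, dim_arg_index: int = 1) -> bool:
    meta = node.args[0].meta.get("tensor_meta")
    if meta is None:
        # No shape metadata recorded for this input; reject conservatively.
        return False
    dim = node.args[dim_arg_index] % len(meta.shape)  # normalize negative dims
    size = meta.shape[dim]
    # A dynamic extent is a torch.SymInt, not a plain int, so this returns
    # False and the op falls back to PyTorch instead of TensorRT.
    return isinstance(size, int)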

py/torch_tensorrt/dynamo/conversion/impl/topk.py

Lines changed: 9 additions & 3 deletions
@@ -10,9 +10,10 @@
     flatten_dims,
     get_axes_for_reduce_op,
     get_positive_dim,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 
 
 def argmax_argmin(
@@ -155,9 +156,14 @@ def topk(
         k,
         get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape))),
     )
+
+    # The topk layer supports a dynamic k value, but we cannot determine the
+    # supported dynamic topk value at compile time.
+    assert k != DYNAMIC_DIM, "k value cannot be dynamic!"
+
     # TensorRT ITopKLayer does not have a sorted flag; it always returns the sorted topk elements,
     # so no matter whether sorted is True or False, the returned topk Tensor object is always sorted
-    set_layer_name(topk_layer, target, name, source_ir)
+    set_layer_name(topk_layer, target, f"{name}_topk", source_ir)
 
     if return_indices:
         return topk_layer.get_output(0), topk_layer.get_output(1)
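
The new assert matters beyond topk itself: aten.sort is lowered through this same helper with k taken from the full extent of the sorted dimension (as the validator's k = shape[dim] above suggests), so a dynamic extent would leave the ITopK layer without a concrete k at build time. A minimal sketch of the guard, assuming DYNAMIC_DIM is the sentinel torch_tensorrt reports for dynamic extents in a TensorRT tensor shape (as imported above); the helper name is illustrative:

# Hedged sketch: resolve k from the input shape and reject dynamic extents
# before an ITopK layer is created.
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM


def resolve_static_k(input_shape, dim: int) -> int:
    k = input_shape[dim]
    # TensorRT needs a concrete k when the layer is built; a dynamic extent
    # along the reduced axis cannot be resolved until runtime.
    assert k != DYNAMIC_DIM, "k value cannot be dynamic!"
    return k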

tests/py/dynamo/conversion/test_any.py

Lines changed: 148 additions & 0 deletions
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -184,5 +185,152 @@ def forward(self, x):
         )
 
 
+class TestAnyConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+            ),
+            (
+                "2d_dynamic_int32",
+                (2, 2),
+                (2, 2),
+                (3, 2),
+                torch.int32,
+            ),
+            (
+                "4d_dynamic_bool",
+                (1, 2, 1, 1),
+                (2, 2, 2, 2),
+                (2, 2, 4, 3),
+                torch.bool,
+            ),
+        ]
+    )
+    def test_any_dynamic(self, _, min_shape, opt_shape, max_shape, type):
+        class Any(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.default(x)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Any(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dim_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                2,
+                True,
+            ),
+            (
+                "4d_dynamic_dim_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                -2,
+                False,
+            ),
+            (
+                "3d_dynamic_dim_bool",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.bool,
+                0,
+                True,
+            ),
+        ]
+    )
+    def test_any_dynamic_dim(
+        self, _, min_shape, opt_shape, max_shape, type, dim, keep_dims
+    ):
+        class AnyDim(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dim(x, dim, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDim(),
+            input_specs,
+        )
+
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_dims_float",
+                (2, 1, 1),
+                (2, 2, 1),
+                (3, 2, 4),
+                torch.float,
+                [1, 2],
+                True,
+            ),
+            (
+                "4d_dynamic_dims_int32",
+                (1, 1, 4, 1),
+                (2, 2, 4, 2),
+                (2, 4, 4, 3),
+                torch.int32,
+                [2, -1],
+                False,
+            ),
+            (
+                "3d_dynamic_dims_bool",
+                (1, 4, 1),
+                (2, 4, 2),
+                (4, 4, 3),
+                torch.bool,
+                [0, 1, 2],
+                False,
+            ),
+        ]
+    )
+    def test_any_dynamic_dims(
+        self, _, min_shape, opt_shape, max_shape, type, dims, keep_dims
+    ):
+        class AnyDims(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.any.dims(x, dims, keep_dims)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            AnyDims(),
+            input_specs,
+        )
+
+
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_sort_aten.py

Lines changed: 51 additions & 0 deletions
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -32,5 +33,55 @@ def forward(self, x):
         )
 
 
+class TestSortConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_descending",
+                (2, 1, 4),
+                (3, 2, 4),
+                (3, 3, 4),
+                2,
+                True,
+            ),
+            (
+                "4d_dynamic_ascending",
+                (2, 2, 1, 4),
+                (2, 2, 2, 4),
+                (3, 3, 2, 4),
+                3,
+                False,
+            ),
+            (
+                "4d_dynamic_descending_neg_dim",
+                (1, 3, 1, 1),
+                (2, 3, 2, 2),
+                (3, 3, 2, 4),
+                -3,
+                True,
+            ),
+        ]
+    )
+    def test_sort_dynamic(self, _, min_shape, opt_shape, max_shape, dim, descending):
+        class Sort(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.sort.default(x, dim, descending)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Sort(),
+            input_specs,
+            output_dtypes=[torch.float, torch.int64],
+            use_dynamo_tracer=True,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
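
The output_dtypes=[torch.float, torch.int64] argument reflects that aten.sort produces two outputs, the sorted values and their original indices, and the indices are always int64 in PyTorch. A small eager-mode illustration (not part of the test file):

# Hedged illustration: aten.sort returns a (values, indices) pair and the
# indices tensor is always torch.int64.
import torch

x = torch.tensor([[3.0, 1.0, 2.0]])
values, indices = torch.ops.aten.sort.default(x, -1, False)
print(values)         # tensor([[1., 2., 3.]])
print(indices.dtype)  # torch.int64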

tests/py/dynamo/conversion/test_trunc_aten.py

Lines changed: 44 additions & 0 deletions
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -48,5 +49,48 @@ def forward(self, input):
         )
 
 
+class TestTruncConverterDynamic(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "3d_dynamic_int32",
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 4, 5),
+                torch.int32,
+                False,
+            ),
+            (
+                "3d_dynamic_float32",
+                (2, 1, 1),
+                (2, 2, 2),
+                (2, 4, 5),
+                torch.float32,
+                True,
+            ),
+        ]
+    )
+    def test_trunc_dynamic(
+        self, _, min_shape, opt_shape, max_shape, type, enable_passes
+    ):
+        class Trunc(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.trunc.default(input)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Trunc(),
+            input_specs,
+            enable_passes=enable_passes,
+        )
+
+
 if __name__ == "__main__":
     run_tests()
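
Outside the test harness, the same supports_dynamic_shapes registrations let these ops compile with a dynamic input range through the dynamo frontend. A hedged end-to-end sketch (the module and shape ranges are made up for illustration; it assumes the standard torch_tensorrt.compile dynamo path, the torch_tensorrt.Input range API used in the tests above, and a CUDA device):

# Hedged usage sketch: compile a toy module that uses sort/trunc/any with a
# dynamic batch dimension. Shapes and the module itself are illustrative.
import torch
import torch_tensorrt


class Example(torch.nn.Module):
    def forward(self, x):
        values, _ = torch.sort(x, dim=-1, descending=True)  # sorted dim stays static
        return torch.trunc(values), torch.any(x > 0, dim=1)


inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 8),
        opt_shape=(4, 3, 8),
        max_shape=(8, 3, 8),
        dtype=torch.float32,
    )
]
trt_mod = torch_tensorrt.compile(Example().eval().cuda(), ir="dynamo", inputs=inputs)
out = trt_mod(torch.randn(2, 3, 8).cuda())  # any batch size in [1, 8] works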
