Commit 46cfa35

fix: Add support for negative dimensions in reduce (#2347)
1 parent 253bbd1 · commit 46cfa35
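
In short: the dynamo reduce converters (amax and sum) previously passed dim straight into the axes-bitmask computation, which fails for negative dimensions. This commit adds an overloaded get_positive_dim helper to dynamo's converter_utils, normalizes dims through it, and migrates the remaining dynamo converters from the old fx helper to the new one.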

File tree: 10 files changed (+72 −35 lines)


py/torch_tensorrt/dynamo/conversion/converter_utils.py
Lines changed: 39 additions & 1 deletion

@@ -1,7 +1,7 @@
 import functools
 import logging
 import re
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, overload

 import numpy as np
 import tensorrt as trt
@@ -314,3 +314,41 @@ def get_trt_tensor(
         return input_val
     else:
         raise AssertionError(f"Cannot convert {input_val} to TRT constant")
+
+
+@overload
+def get_positive_dim(dim: int, dim_size: int) -> int:
+    ...
+
+
+@overload
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
+    ...
+
+
+def get_positive_dim(
+    dim: Union[int, Sequence[int]], dim_size: int
+) -> Union[int, Tuple[int, ...]]:
+    """
+    Given an integer number or tuple that represents dimension(s) in the array,
+    transform it to a positive integer dim if it's negative. Otherwise, do
+    nothing.
+
+    Args:
+        dim (Union[int, Sequence[int]]): An integer or a sequence of integers that represent dimension(s) in an array.
+        dim_size (int): The size of the dimension in the array.
+
+    Returns:
+        A positive integer or tuple of integers that represent the same dimension as the given dim.
+    """
+
+    def positive_dim(d: int) -> int:
+        if d < 0:
+            return d % dim_size
+        return d
+
+    return (
+        positive_dim(dim)
+        if isinstance(dim, int)
+        else tuple(positive_dim(d) for d in dim)
+    )
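
The normalization itself is plain modulo arithmetic on negative indices, applied per element while preserving the int-versus-sequence shape of the input. A quick behavior check, using the import path this commit establishes:

from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim

# d % dim_size maps a negative dim into [0, dim_size):
assert get_positive_dim(-1, 4) == 3             # last dim of a rank-4 tensor
assert get_positive_dim(2, 4) == 2              # non-negative dims pass through
assert get_positive_dim((-3, -1), 4) == (1, 3)  # sequences come back as tuples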

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
Lines changed: 1 addition & 1 deletion

@@ -6,12 +6,12 @@
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,
 )
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 from torch_tensorrt.fx.converters.converter_utils import (
-    get_positive_dim,
     get_trt_plugin,
     has_dynamic_shape,
     set_layer_name,

py/torch_tensorrt/dynamo/conversion/impl/permutation.py
Lines changed: 3 additions & 5 deletions

@@ -2,10 +2,8 @@

 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_positive_dim,
-    set_layer_name,
-)
+from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

@@ -22,7 +20,7 @@ def permute(
            f"permute received input {input} that is not a TensorRT ITensor"
        )

-    permutation = [get_positive_dim(i, len(input.shape)) for i in permutation]
+    permutation = get_positive_dim(permutation, len(input.shape))

     layer = network.add_shuffle(input)
     layer.second_transpose = tuple(permutation)
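
With the new Sequence overload, the removed per-element list comprehension and the single call agree; a small sanity check (example values assumed, not from the commit):

from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim

perm, rank = (-1, 0, -2), 3
assert get_positive_dim(perm, rank) == tuple(get_positive_dim(i, rank) for i in perm)
# both yield (2, 0, 1)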

py/torch_tensorrt/dynamo/conversion/impl/reduce.py
Lines changed: 5 additions & 4 deletions

@@ -1,11 +1,12 @@
-from typing import Optional, Sequence, Tuple, Union
+from typing import Optional, Sequence, Union

 import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
     get_axes_for_reduce_op,
+    get_positive_dim,
 )
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
@@ -17,7 +18,7 @@ def amax(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    dim: Union[int, Tuple[int]],
+    dim: Union[int, Sequence[int]],
     keepdim: bool = False,
 ) -> TRTTensor:
     if (isinstance(input_val, TRTTensor)) and (
@@ -28,7 +29,7 @@ def amax(
     layer = network.add_reduce(
         input_val,
         trt.ReduceOperation.MAX,
-        axes=get_axes_for_reduce_op(dim),
+        axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
         keep_dims=keepdim,
     )
     set_layer_name(layer, target, name, source_ir)
@@ -54,7 +55,7 @@ def sum(
     layer = network.add_reduce(
         input_val,
         trt.ReduceOperation.SUM,
-        axes=get_axes_for_reduce_op(dim),
+        axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
         keep_dims=keepdim,
     )
     set_layer_name(layer, target, name, source_ir)
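
The normalization matters here because get_axes_for_reduce_op converts dims into the axes bitmask that TensorRT's IReduceLayer expects. A minimal sketch of such a helper (an assumption for illustration; the real implementation may differ) shows why a negative dim cannot be fed in directly:

from typing import Sequence, Union

def get_axes_for_reduce_op(dim: Union[int, Sequence[int]]) -> int:
    # Sketch: build the reduce-axes bitmask; assumes dims are non-negative.
    dims = (dim,) if isinstance(dim, int) else dim
    axes = 0
    for d in dims:
        axes |= 1 << d  # 1 << -1 raises ValueError: negative shift count
    return axes

# With dim=-1 on a rank-4 input, get_positive_dim(-1, 4) == 3 runs first, so:
assert get_axes_for_reduce_op(3) == 0b1000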

py/torch_tensorrt/dynamo/conversion/impl/select.py
Lines changed: 2 additions & 5 deletions

@@ -3,12 +3,9 @@
 import numpy as np
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_positive_dim,
-    has_dynamic_shape,
-    to_numpy,
-)
+from torch_tensorrt.fx.converters.converter_utils import has_dynamic_shape, to_numpy
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Lines changed: 1 addition & 1 deletion

@@ -3,9 +3,9 @@

 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
-    get_positive_dim,
     has_dynamic_shape,
     prepend_ones,
     set_layer_name,

py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
Lines changed: 12 additions & 15 deletions

@@ -1,11 +1,9 @@
-from typing import Any, Optional, cast
+from typing import Optional, Sequence, Union

 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.converters.converter_utils import (
-    get_positive_dim,
-    set_layer_name,
-)
+from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims

@@ -16,19 +14,18 @@ def squeeze(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: Optional[Any] = None,
+    dim: Optional[Union[int, Sequence[int]]] = None,
 ) -> TRTTensor:
-    dims = []
-    if dim is not None:
-        if isinstance(dim, int):
-            dims.append(cast(Optional[int], dim))
-        else:
-            for dim in dim:
-                dims.append(cast(Optional[int], dim))
-
     # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
     # dim, which is a very rare case. For now we just claim not supporting dim=None.
-    assert not (len(dims) == 0), "We don't support dim=None right now for squeeze."
+    assert dim is not None, "We don't support dim=None right now for squeeze."
+    dims = []
+
+    if isinstance(dim, int):
+        dims.append(dim)
+    else:
+        for dim in dim:
+            dims.append(dim)

     new_dims = []
     for dim in dims:
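
A side note on the refactor: the loop "for dim in dim" shadows the parameter. An equivalent, more compact normalization (a hypothetical alternative, not what the commit ships) would be:

# Same result without shadowing the dim parameter:
dims = [dim] if isinstance(dim, int) else list(dim)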

py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py
Lines changed: 3 additions & 3 deletions

@@ -2,11 +2,11 @@

 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
-from torch_tensorrt.fx.converters.converter_utils import (
+from torch_tensorrt.dynamo.conversion.converter_utils import (
     get_positive_dim,
-    set_layer_name,
+    get_trt_tensor,
 )
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import Shape, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims


tests/py/dynamo/conversion/test_amax_aten.py
Lines changed: 3 additions & 0 deletions

@@ -13,6 +13,7 @@ class TestAmaxConverter(DispatchTestCase):
             ((2, 3, 4, 5), 3, True),
             ((2, 3, 4, 5), 2, False),
             ((6, 7, 5, 4, 5), 4, False),
+            ((1, 5, 2, 1), -1, True),
         ]
     )
     def test_amax_dim_int_default(self, input_shape, dim, keep_dims):
@@ -53,6 +54,7 @@ def forward(self, x):
             ((2, 3, 4, 5), 3, True, torch.int, -10, 10),
             ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
             ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
+            ((1, 5, 2, 1), -4, False, torch.int32, -5, 5),
         ]
     )
     def test_amax_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
@@ -74,6 +76,7 @@ def forward(self, x):
             ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
             ((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),
             ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
+            ((1, 5, 2, 1), [-3, -1], False, torch.int32, -5, 5),
         ]
     )
     def test_amax_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):

tests/py/dynamo/conversion/test_sum_aten.py
Lines changed: 3 additions & 0 deletions

@@ -33,6 +33,8 @@ def forward(self, x):
             ((2, 3, 4, 5), 3, True),
             ((2, 3, 4, 5), None, False),
             ((6, 7, 5, 4, 5), 4, False),
+            ((1, 5, 2, 1), -3, False),
+            ((1, 5, 2, 3), -2, True),
         ]
     )
     def test_sum_dim_int(self, input_shape, dim, keep_dims):
@@ -53,6 +55,7 @@ def forward(self, x):
             ((2, 1, 4, 5), None, True),
             ((2, 3, 4, 5), [0, 1, 2, 3], False),
             ((6, 7, 5, 4, 5), [1, 3, 4], False),
+            ((6, 7, 5, 4, 5), [-5, -4, -2], False),
         ]
     )
     def test_sum_dim_tuple(self, input_shape, dim, keep_dims):
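
Conceptually, the new cases pin the converters to PyTorch's own negative-dim semantics. In plain PyTorch (no TensorRT), the added dims resolve as:

import torch

x = torch.randn(1, 5, 2, 1)
# On a rank-4 tensor, -3 % 4 == 1 and -1 % 4 == 3:
assert torch.equal(x.sum(dim=-3), x.sum(dim=1))
assert torch.equal(x.amax(dim=-1, keepdim=True), x.amax(dim=3, keepdim=True))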
