Commit c04a9da

feat: support torch.ops.aten.sum.(default and dim_IntList) dynamo converter (#2278)
1 parent d6a07bb commit c04a9da

File tree

3 files changed: +161 −1 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 0 deletions
@@ -587,6 +587,26 @@ def aten_ops_amax(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.sum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.sum.dim_IntList)
+def aten_ops_sum(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.reduce.sum(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args_bounds_check(args, 1, replacement=None),
+        args_bounds_check(args, 2, replacement=False),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.exp.default)  # type: ignore[misc]
 def aten_ops_exp(
     network: TRTNetwork,
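Registering the same function for both overloads works because aten.sum.default carries only the input tensor, while aten.sum.dim_IntList also carries dim and an optional keepdim. The args_bounds_check calls paper over that difference by substituting defaults for positional arguments the overload did not supply. A minimal sketch of that helper's likely behavior (the real one lives in the dynamo converter utilities; this version is only illustrative):

from typing import Any, Optional, Tuple

def args_bounds_check(
    args: Tuple[Any, ...], i: int, replacement: Optional[Any] = None
) -> Any:
    # Return args[i] when the call actually supplied that argument;
    # otherwise fall back to the replacement value.
    return args[i] if len(args) > i else replacement

# aten.sum.default calls carry only (input,), so the converter sees
# dim=None and keepdim=False; aten.sum.dim_IntList calls carry
# (input, dim) or (input, dim, keepdim).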

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 27 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Sequence, Tuple, Union

 import tensorrt as trt
 from torch.fx.node import Target
@@ -33,3 +33,29 @@ def amax(
     )
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
+
+
+def sum(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    dim: Optional[Union[int, Sequence[int]]] = None,
+    keepdim: bool = False,
+) -> TRTTensor:
+    if (isinstance(input_val, TRTTensor)) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
+
+    if dim is None:
+        dim = tuple(range(len(input_val.shape)))
+    layer = network.add_reduce(
+        input_val,
+        trt.ReduceOperation.SUM,
+        axes=get_axes_for_reduce_op(dim),
+        keep_dims=keepdim,
+    )
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
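Two details of the implementation are worth noting. Integer inputs are cast to float32 first, since aten.sum type-promotes integer tensors (which is also why the integer tests below pass check_dtype=False), and dim=None is expanded to every dimension for a full reduction. The reduced dimensions then go through get_axes_for_reduce_op, which packs them into the bitmask TensorRT's IReduceLayer expects. A sketch of that mapping, assuming non-negative dims in explicit-batch mode (the shared converter-utils helper is the authoritative version):

from typing import Sequence, Union

def get_axes_for_reduce_op(dim: Union[int, Sequence[int]]) -> int:
    # IReduceLayer's `axes` argument is a bitmask: bit d is set when
    # dimension d should be reduced.
    if isinstance(dim, int):
        dim = (dim,)
    axes = 0
    for d in dim:
        axes |= 1 << d
    return axes

print(get_axes_for_reduce_op((1, 3)))           # 10 == 0b1010
print(get_axes_for_reduce_op(tuple(range(3))))  # 7 == 0b111, full reduction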
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestSumConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((3, 2, 4),),
+            ((2, 3, 4, 5),),
+            ((2, 3, 4, 5),),
+            ((6, 7, 5, 4, 5),),
+        ]
+    )
+    def test_sum_dim_int_default(self, input_shape):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.default},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True),
+            ((2, 3, 4, 5), 3, True),
+            ((2, 3, 4, 5), None, False),
+            ((6, 7, 5, 4, 5), 4, False),
+        ]
+    )
+    def test_sum_dim_int(self, input_shape, dim, keep_dims):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.dim_IntList},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1], True),
+            ((2, 1, 4, 5), None, True),
+            ((2, 3, 4, 5), [0, 1, 2, 3], False),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False),
+        ]
+    )
+    def test_sum_dim_tuple(self, input_shape, dim, keep_dims):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randn(*input_shape)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.dim_IntList},
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), 1, True, torch.int, 0, 5),
+            ((2, 3, 4, 5), None, True, torch.int, -10, 10),
+            ((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
+        ]
+    )
+    def test_sum_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.dim_IntList},
+            check_dtype=False,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 2, 4), [1], True, torch.int, 0, 5),
+            ((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
+            ((2, 3, 4, 5), None, False, torch.int32, -5, 0),
+            ((6, 7, 5, 4, 5), [1, 3, 4], False, torch.int32, -5, 5),
+        ]
+    )
+    def test_sum_dim_tuple_int(self, input_shape, dim, keep_dims, dtype, low, high):
+        class Sum(nn.Module):
+            def forward(self, x):
+                return torch.sum(x, dim=dim, keepdim=keep_dims)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            Sum(),
+            inputs,
+            expected_ops={torch.ops.aten.sum.dim_IntList},
+            check_dtype=False,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
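The test cases mirror how torch.sum lowers to the two ATen overloads: a call with no dim traces to aten.sum.default, while any dim argument (int or list) traces to aten.sum.dim_IntList. A quick way to confirm this dispatch, assuming a recent PyTorch with torch.export available:

import torch

class Sum(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x), torch.sum(x, dim=(0, 2), keepdim=True)

# The exported graph contains one aten.sum.default call and one
# aten.sum.dim_IntList call, the two overloads this commit converts.
ep = torch.export.export(Sum(), (torch.randn(3, 2, 4),))
print(ep.graph)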
