Skip to content

Commit 80bbd8b

Browse files
authored
feat: support prod, max, min, and mean via reduce layer (#2355)
1 parent 65e8ec7 commit 80bbd8b

File tree

9 files changed

+460
-28
lines changed

9 files changed

+460
-28
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import operator
23
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
34

45
import numpy as np
@@ -155,12 +156,12 @@ def aten_ops_sigmoid(
155156
)
156157

157158

158-
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
159+
@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor) # type: ignore[misc]
159160
@enforce_tensor_types(
160161
{
161162
0: (TRTTensor,),
162163
}
163-
)
164+
) # type: ignore[misc]
164165
def aten_ops_index(
165166
ctx: ConversionContext,
166167
target: Target,
@@ -685,7 +686,7 @@ def aten_ops_amax(
685686
SourceIR.ATEN,
686687
name,
687688
args[0],
688-
args[1],
689+
args_bounds_check(args, 1, replacement=[]),
689690
args_bounds_check(args, 2, replacement=False),
690691
)
691692

@@ -724,6 +725,97 @@ def aten_ops_sum(
724725
return sum_
725726

726727

728+
@dynamo_tensorrt_converter(torch.ops.aten.prod.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.prod.dim_int)  # type: ignore[misc]
def aten_ops_prod(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten.prod (full-reduce and single-dim overloads) to a TRT reduce layer.

    args layout: (input, dim?, keepdim?) — dim/keepdim are absent for the
    .default overload, so they fall back to None / False.
    """
    return impl.reduce.prod(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        dim=args_bounds_check(args, 1, replacement=None),
        keepdim=args_bounds_check(args, 2, replacement=False),
    )
746+
747+
748+
def one_user_validator(node: Node) -> bool:
749+
# Validate only one user, which is a getitem node that accesses the first element in the list
750+
return (
751+
len(node.users) == 1
752+
and list(node.users)[0].target == operator.getitem
753+
and list(node.users)[0].args[1] == 0
754+
)
755+
756+
757+
@dynamo_tensorrt_converter(torch.ops.aten.max.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator)  # type: ignore[misc]
def aten_ops_max(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten.max (full-reduce and .dim overloads) to a TRT reduce layer.

    The .dim overload is a (values, indices) op; the capability validator
    restricts it to graphs that only read the values output.
    """
    wants_indices = target == torch.ops.aten.max.dim
    return impl.reduce.max(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        dim=args_bounds_check(args, 1, replacement=None),
        keepdim=args_bounds_check(args, 2, replacement=False),
        return_indices=wants_indices,
    )
776+
777+
778+
@dynamo_tensorrt_converter(torch.ops.aten.min.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.min.dim, capability_validator=one_user_validator)  # type: ignore[misc]
def aten_ops_min(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten.min (full-reduce and .dim overloads) to a TRT reduce layer.

    The .dim overload is a (values, indices) op; the capability validator
    restricts it to graphs that only read the values output.
    """
    wants_indices = target == torch.ops.aten.min.dim
    return impl.reduce.min(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        dim=args_bounds_check(args, 1, replacement=None),
        keepdim=args_bounds_check(args, 2, replacement=False),
        return_indices=wants_indices,
    )
797+
798+
799+
@dynamo_tensorrt_converter(torch.ops.aten.mean.default)  # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.mean.dim)  # type: ignore[misc]
def aten_ops_mean(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert aten.mean (full-reduce and .dim overloads) to a TRT reduce layer.

    args layout: (input, dim?, keepdim?) — missing trailing args default to
    None / False via args_bounds_check.
    """
    return impl.reduce.mean(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        dim=args_bounds_check(args, 1, replacement=None),
        keepdim=args_bounds_check(args, 2, replacement=False),
    )
817+
818+
727819
@dynamo_tensorrt_converter(torch.ops.aten.exp.default) # type: ignore[misc]
728820
def aten_ops_exp(
729821
ctx: ConversionContext,
@@ -1150,7 +1242,7 @@ def aten_ops_mul(
11501242

11511243

11521244
@dynamo_tensorrt_converter(torch.ops.aten.maximum.default) # type: ignore[misc]
1153-
def aten_ops_max(
1245+
def aten_ops_maximum(
11541246
ctx: ConversionContext,
11551247
target: Target,
11561248
args: Tuple[Argument, ...],
@@ -1168,7 +1260,7 @@ def aten_ops_max(
11681260

11691261

11701262
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default) # type: ignore[misc]
1171-
def aten_ops_min(
1263+
def aten_ops_minimum(
11721264
ctx: ConversionContext,
11731265
target: Target,
11741266
args: Tuple[Argument, ...],

py/torch_tensorrt/dynamo/conversion/impl/reduce.py

Lines changed: 125 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Sequence, Union
1+
from typing import Optional, Sequence, Tuple, Union
22

33
import tensorrt as trt
44
from torch.fx.node import Target
@@ -27,6 +27,9 @@ def amax(
2727
):
2828
input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)
2929

30+
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
31+
dim = tuple(range(len(input_val.shape)))
32+
3033
layer = ctx.net.add_reduce(
3134
input_val,
3235
trt.ReduceOperation.MAX,
@@ -43,8 +46,8 @@ def sum(
4346
source_ir: Optional[SourceIR],
4447
name: str,
4548
input_val: TRTTensor,
46-
dim: Optional[Union[int, Sequence[int]]] = None,
47-
keepdim: bool = False,
49+
dim: Optional[Union[int, Sequence[int]]],
50+
keepdim: bool,
4851
) -> TRTTensor:
4952
if (isinstance(input_val, TRTTensor)) and (
5053
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
@@ -53,6 +56,7 @@ def sum(
5356

5457
if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
5558
dim = tuple(range(len(input_val.shape)))
59+
5660
layer = ctx.net.add_reduce(
5761
input_val,
5862
trt.ReduceOperation.SUM,
@@ -61,3 +65,121 @@ def sum(
6165
)
6266
set_layer_name(layer, target, name, source_ir)
6367
return layer.get_output(0)
68+
69+
70+
def prod(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: Optional[Union[int, Sequence[int]]],
    keepdim: bool,
) -> TRTTensor:
    """Reduce ``input_val`` by product over ``dim``.

    ``dim`` may be an int, a sequence of ints, or None; None (or an empty
    sequence) reduces over every dimension. int8/int32 inputs are upcast to
    float32 first, mirroring the handling in ``sum``/``amax``.
    """
    if (isinstance(input_val, TRTTensor)) and (
        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
    ):
        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

    # Consistency fix: treat an empty dim list/tuple like None (reduce all
    # dims), matching sum/mean/amax in this module.
    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
        dim = tuple(range(len(input_val.shape)))

    layer = ctx.net.add_reduce(
        input_val,
        trt.ReduceOperation.PROD,
        axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
        keep_dims=keepdim,
    )
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)
95+
96+
97+
def max(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: Optional[Union[int, Sequence[int]]],
    keepdim: bool,
    return_indices: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
    """Reduce ``input_val`` by max over ``dim`` (all dims when None/empty).

    int8/int32 inputs are upcast to float32 first, mirroring ``sum``/``amax``.
    When ``return_indices`` is True a ``(values, None)`` pair is returned —
    the indices tensor is NOT computed; callers are gated by
    ``one_user_validator`` so only the values output is ever consumed.
    """
    if (isinstance(input_val, TRTTensor)) and (
        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
    ):
        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

    # Consistency fix: treat an empty dim list/tuple like None (reduce all
    # dims), matching sum/mean/amax in this module.
    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
        dim = tuple(range(len(input_val.shape)))

    layer = ctx.net.add_reduce(
        input_val,
        trt.ReduceOperation.MAX,
        axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
        keep_dims=keepdim,
    )
    set_layer_name(layer, target, name, source_ir)

    if return_indices:
        # Placeholder for the unused indices output of aten.max.dim.
        return layer.get_output(0), None

    return layer.get_output(0)
127+
128+
129+
def min(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: Optional[Union[int, Sequence[int]]],
    keepdim: bool,
    return_indices: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, TRTTensor]]:
    """Reduce ``input_val`` by min over ``dim`` (all dims when None/empty).

    int8/int32 inputs are upcast to float32 first, mirroring ``sum``/``amax``.
    When ``return_indices`` is True a ``(values, None)`` pair is returned —
    the indices tensor is NOT computed; callers are gated by
    ``one_user_validator`` so only the values output is ever consumed.
    """
    if (isinstance(input_val, TRTTensor)) and (
        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
    ):
        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

    # Consistency fix: treat an empty dim list/tuple like None (reduce all
    # dims), matching sum/mean/amax in this module.
    if dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0):
        dim = tuple(range(len(input_val.shape)))

    layer = ctx.net.add_reduce(
        input_val,
        trt.ReduceOperation.MIN,
        axes=get_axes_for_reduce_op(get_positive_dim(dim, len(input_val.shape))),
        keep_dims=keepdim,
    )
    set_layer_name(layer, target, name, source_ir)

    if return_indices:
        # Placeholder for the unused indices output of aten.min.dim.
        return layer.get_output(0), None

    return layer.get_output(0)
159+
160+
161+
def mean(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
    dim: Optional[Union[int, Sequence[int]]],
    keepdim: bool,
) -> TRTTensor:
    """Average ``input_val`` over ``dim``; None or an empty sequence reduces
    every dimension. int8/int32 inputs are upcast to float32 before reducing,
    matching the other reductions in this module.
    """
    needs_cast = isinstance(input_val, TRTTensor) and (
        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
    )
    if needs_cast:
        input_val = cast_trt_tensor(ctx, input_val, trt.float32, name)

    reduce_all = dim is None or (isinstance(dim, (tuple, list)) and len(dim) == 0)
    target_dims = tuple(range(len(input_val.shape))) if reduce_all else dim

    avg_layer = ctx.net.add_reduce(
        input_val,
        trt.ReduceOperation.AVG,
        axes=get_axes_for_reduce_op(
            get_positive_dim(target_dims, len(input_val.shape))
        ),
        keep_dims=keepdim,
    )
    set_layer_name(avg_layer, target, name, source_ir)
    return avg_layer.get_output(0)

tests/py/dynamo/conversion/test_amax_aten.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def forward(self, x):
2929

3030
@parameterized.expand(
3131
[
32+
((1, 2, 4), [], True),
3233
((3, 2, 4), [1], True),
3334
((2, 1, 4, 5), [0, 3], True),
3435
((2, 3, 4, 5), [0, 1, 2, 3], False),
@@ -69,6 +70,7 @@ def forward(self, x):
6970

7071
@parameterized.expand(
7172
[
73+
((1, 2, 4), [], True, torch.int, 0, 5),
7274
((3, 2, 4), [1], True, torch.int, 0, 5),
7375
((2, 1, 4, 5), [0, 3], True, torch.int, -10, 10),
7476
((2, 3, 4, 5), [0, 1, 2, 3], False, torch.int32, -5, 0),

tests/py/dynamo/conversion/test_max_aten.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,67 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5-
from torch_tensorrt import Input
65

76
from .harness import DispatchTestCase
87

98

109
class TestMaxConverter(DispatchTestCase):
1110
@parameterized.expand(
1211
[
13-
("2d", (2, 1)),
14-
("3d", (2, 1, 2)),
12+
((1, 2),),
13+
((3, 2, 4),),
14+
((2, 3, 4, 5),),
15+
((6, 7, 5, 4, 5),),
1516
]
1617
)
17-
def test_max(self, _, shape):
18-
class max(nn.Module):
19-
def forward(self, lhs_val, rhs_val):
20-
return torch.ops.aten.maximum.default(lhs_val, rhs_val)
18+
def test_max_dim_int_default(self, input_shape):
19+
class Max(nn.Module):
20+
def forward(self, x):
21+
return torch.ops.aten.max.default(x)
2122

22-
inputs = [torch.randn(shape), torch.randn(shape)]
23+
inputs = [torch.randn(*input_shape)]
2324
self.run_test(
24-
max(),
25+
Max(),
2526
inputs,
26-
# expected_ops={torch.ops.aten.maximum.default},
27+
)
28+
29+
@parameterized.expand(
30+
[
31+
((3, 2, 4), 1, True),
32+
((2, 3, 4, 5), 3, True),
33+
((6, 7, 5, 4, 5), 4, False),
34+
((1, 5, 2, 1), -3, False),
35+
((1, 5, 2, 3), -2, True),
36+
]
37+
)
38+
def test_max_dim_int(self, input_shape, dim, keep_dims):
39+
class Max(nn.Module):
40+
def forward(self, x):
41+
return torch.ops.aten.max.dim(x, dim, keep_dims)[0]
42+
43+
inputs = [torch.randn(*input_shape)]
44+
self.run_test(
45+
Max(),
46+
inputs,
47+
)
48+
49+
@parameterized.expand(
50+
[
51+
((3, 2, 4), 1, True, torch.int, 0, 5),
52+
((2, 3, 4, 5), 2, False, torch.int32, -5, 0),
53+
((6, 7, 5, 4, 5), 4, False, torch.int32, -5, 5),
54+
]
55+
)
56+
def test_max_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
57+
class Max(nn.Module):
58+
def forward(self, x):
59+
return torch.ops.aten.max.dim(x, dim, keep_dims)[0]
60+
61+
inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
62+
self.run_test(
63+
Max(),
64+
inputs,
65+
check_dtype=False,
2766
)
2867

2968

0 commit comments

Comments
 (0)