Skip to content

Commit b51c121

Browse files
committed
add overloads, update tests, and fix a bug
1 parent bf020cd commit b51c121

16 files changed: +537 additions, −149 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 81 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
import logging
22
from typing import Any, Dict, Optional, Sequence, Tuple, Union
33

4-
import tensorrt as trt
54
import torch
65
from torch.fx.node import Argument, Node, Target
76
from torch_tensorrt.dynamo._SourceIR import SourceIR
87
from torch_tensorrt.dynamo.conversion import impl
9-
from torch_tensorrt.dynamo.conversion.converter_utils import (
10-
cast_int_int_div_trt_tensor,
11-
cast_trt_tensor,
12-
)
13-
from torch_tensorrt.fx.converters import acc_ops_converters
148
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
159

1610
from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
4842
)
4943

5044

51-
@dynamo_tensorrt_converter(torch.ops.aten.div.default) # type: ignore[misc]
52-
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc]
53-
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) # type: ignore[misc]
54-
def aten_ops_div(
55-
network: TRTNetwork,
56-
target: Target,
57-
args: Tuple[Argument, ...],
58-
kwargs: Dict[str, Argument],
59-
name: str,
60-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
61-
kwargs_new = {
62-
"input": args[0],
63-
"other": args[1],
64-
}
65-
# If both are TRTTensor, both are cast to float32
66-
if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
67-
kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
68-
network,
69-
kwargs_new["input"],
70-
kwargs_new["other"],
71-
name,
72-
)
73-
# If one is TRTTensor, it is cast to float32
74-
elif isinstance(args[0], TRTTensor) and (
75-
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
76-
):
77-
kwargs_new["input"] = cast_trt_tensor(
78-
network, kwargs_new["input"], trt.float32, name, target
79-
)
80-
elif isinstance(args[1], TRTTensor) and (
81-
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
82-
):
83-
kwargs_new["other"] = cast_trt_tensor(
84-
network, kwargs_new["other"], trt.float32, name, target
85-
)
86-
rounding_mode = kwargs.get("rounding_mode")
87-
if rounding_mode is None:
88-
return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
89-
elif rounding_mode == "floor":
90-
return acc_ops_converters.acc_ops_floor_div(
91-
network, target, None, kwargs_new, name
92-
)
93-
elif rounding_mode == "trunc":
94-
return impl.elementwise.trunc_div(
95-
network, target, SourceIR.ATEN, name, args[0], args[1]
96-
)
97-
else:
98-
raise RuntimeError(
99-
f"Target {target} does not support rounding mode {rounding_mode}"
100-
)
101-
102-
10345
def embedding_param_validator(embedding_node: Node) -> bool:
10446
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
10547
sparse = args_bounds_check(embedding_node.args, 4)
@@ -982,24 +924,39 @@ def aten_ops_isinf(
982924

983925

984926
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
927+
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
985928
def aten_ops_add(
986929
network: TRTNetwork,
987930
target: Target,
988931
args: Tuple[Argument, ...],
989932
kwargs: Dict[str, Argument],
990933
name: str,
991934
) -> Union[TRTTensor, Sequence[TRTTensor]]:
935+
other = args[1]
936+
alpha = kwargs.get("alpha", 1)
937+
938+
if alpha != 1:
939+
other = impl.elementwise.mul(
940+
network,
941+
target,
942+
SourceIR.ATEN,
943+
name,
944+
other,
945+
alpha,
946+
)
947+
992948
return impl.elementwise.add(
993949
network,
994950
target,
995951
SourceIR.ATEN,
996952
name,
997953
args[0],
998-
args[1],
954+
other,
999955
)
1000956

1001957

1002958
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
959+
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
1003960
def aten_ops_mul(
1004961
network: TRTNetwork,
1005962
target: Target,
@@ -1054,43 +1011,86 @@ def aten_ops_min(
10541011

10551012

10561013
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
1014+
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
10571015
def aten_ops_sub(
10581016
network: TRTNetwork,
10591017
target: Target,
10601018
args: Tuple[Argument, ...],
10611019
kwargs: Dict[str, Argument],
10621020
name: str,
10631021
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1022+
other = args[1]
1023+
alpha = kwargs.get("alpha", 1)
1024+
1025+
if alpha != 1:
1026+
other = impl.elementwise.mul(
1027+
network,
1028+
target,
1029+
SourceIR.ATEN,
1030+
name,
1031+
other,
1032+
alpha,
1033+
)
1034+
10641035
return impl.elementwise.sub(
10651036
network,
10661037
target,
10671038
SourceIR.ATEN,
10681039
name,
10691040
args[0],
1070-
args[1],
1041+
other,
10711042
)
10721043

10731044

1074-
# TODO: keep this or line 54...?
1075-
# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
1076-
# def aten_ops_div(
1077-
# network: TRTNetwork,
1078-
# target: Target,
1079-
# args: Tuple[Argument, ...],
1080-
# kwargs: Dict[str, Argument],
1081-
# name: str,
1082-
# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
1083-
# return impl.elementwise.div(
1084-
# network,
1085-
# target,
1086-
# SourceIR.ATEN,
1087-
# name,
1088-
# args[0],
1089-
# args[1],
1090-
# )
1045+
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
1046+
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
1047+
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
1048+
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
1049+
def aten_ops_div(
1050+
network: TRTNetwork,
1051+
target: Target,
1052+
args: Tuple[Argument, ...],
1053+
kwargs: Dict[str, Argument],
1054+
name: str,
1055+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1056+
rounding_mode = kwargs.get("rounding_mode")
1057+
1058+
if rounding_mode is None:
1059+
return impl.elementwise.div(
1060+
network,
1061+
target,
1062+
SourceIR.ATEN,
1063+
name,
1064+
args[0],
1065+
args[1],
1066+
)
1067+
elif rounding_mode == "floor":
1068+
return impl.elementwise.floor_divide(
1069+
network,
1070+
target,
1071+
SourceIR.ATEN,
1072+
name,
1073+
args[0],
1074+
args[1],
1075+
)
1076+
elif rounding_mode == "trunc":
1077+
return impl.elementwise.trunc_div(
1078+
network,
1079+
target,
1080+
SourceIR.ATEN,
1081+
name,
1082+
args[0],
1083+
args[1],
1084+
)
1085+
else:
1086+
raise RuntimeError(
1087+
f"Target {target} does not support rounding mode {rounding_mode}"
1088+
)
10911089

10921090

10931091
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
1092+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
1093+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
10941094
def aten_ops_pow(
10951095
network: TRTNetwork,
10961096
target: Target,
@@ -1109,6 +1109,7 @@ def aten_ops_pow(
11091109

11101110

11111111
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
1112+
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
11121113
def aten_ops_floor_div(
11131114
network: TRTNetwork,
11141115
target: Target,
@@ -1181,6 +1182,7 @@ def aten_ops_logical_xor(
11811182

11821183

11831184
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
1185+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
11841186
def aten_ops_equal(
11851187
network: TRTNetwork,
11861188
target: Target,
@@ -1199,6 +1201,7 @@ def aten_ops_equal(
11991201

12001202

12011203
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
1204+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
12021205
def aten_ops_greater(
12031206
network: TRTNetwork,
12041207
target: Target,
@@ -1217,6 +1220,7 @@ def aten_ops_greater(
12171220

12181221

12191222
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
1223+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
12201224
def aten_ops_less(
12211225
network: TRTNetwork,
12221226
target: Target,

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def cast_trt_tensor(
9797

9898
if input_val.dtype != trt_dtype:
9999
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
100-
target_str = ConverterRegistry.qualified_name_or_str(target)
100+
target_str = ConverterRegistry.qualified_name_or_str(ConverterRegistry, target)
101101
target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
102102

103103
identity_layer = network.add_identity(input_val)

tests/py/dynamo/conversion/test_add_aten.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,79 @@
11
import torch
22
import torch.nn as nn
3-
from .harness import DispatchTestCase
43
from parameterized import parameterized
54
from torch.testing._internal.common_utils import run_tests
65
from torch_tensorrt import Input
76

7+
from .harness import DispatchTestCase
8+
89

910
class TestAddConverter(DispatchTestCase):
1011
@parameterized.expand(
1112
[
12-
("2d_dim_alpha", (2, 1), 2),
13-
("3d_dim_alpha", (2, 1, 2), 2),
13+
("2d", (2, 1)),
14+
("3d", (2, 1, 2)),
15+
]
16+
)
17+
def test_add_tensor(self, _, shape):
18+
class add(nn.Module):
19+
def forward(self, lhs_val, rhs_val):
20+
return torch.add(lhs_val, rhs_val)
21+
22+
inputs = [torch.randn(shape), torch.randn(shape)]
23+
self.run_test(
24+
add(),
25+
inputs,
26+
expected_ops={torch.ops.aten.add.Tensor},
27+
)
28+
29+
@parameterized.expand(
30+
[
31+
("2d", (2, 1), 1),
32+
("3d", (2, 1, 2), 2.0),
1433
]
1534
)
16-
def test_add(self, _, x, alpha):
35+
def test_add_tensor_alpha(self, _, shape, alpha):
1736
class add(nn.Module):
1837
def forward(self, lhs_val, rhs_val):
19-
return lhs_val + rhs_val
38+
return torch.add(lhs_val, rhs_val, alpha=alpha)
39+
40+
inputs = [torch.randn(shape), torch.randn(shape)]
41+
self.run_test(
42+
add(),
43+
inputs,
44+
expected_ops={torch.ops.aten.add.Tensor},
45+
)
46+
47+
@parameterized.expand(
48+
[
49+
("2d", (2, 1), 1.0),
50+
("3d", (2, 1, 2), 2),
51+
]
52+
)
53+
def test_add_scalar(self, _, shape, scalar):
54+
class add(nn.Module):
55+
def forward(self, lhs_val):
56+
return torch.add(lhs_val, scalar)
57+
58+
inputs = [torch.randn(shape)]
59+
self.run_test(
60+
add(),
61+
inputs,
62+
expected_ops={torch.ops.aten.add.Tensor},
63+
)
64+
65+
@parameterized.expand(
66+
[
67+
("2d", (2, 1), 1.0, 1.0),
68+
("3d", (2, 1, 2), 2, 2),
69+
]
70+
)
71+
def test_add_scalar_alpha(self, _, shape, scalar, alpha):
72+
class add(nn.Module):
73+
def forward(self, lhs_val):
74+
return torch.add(lhs_val, scalar, alpha=alpha)
2075

21-
inputs = [torch.randn(x) + 1, torch.randn(x) + 1]
76+
inputs = [torch.randn(shape)]
2277
self.run_test(
2378
add(),
2479
inputs,

0 commit comments

Comments (0)