Skip to content

Commit 40f8064

Browse files
authored
feat: support many elementwise dynamo converters (#2263)
1 parent ba2a300 commit 40f8064

18 files changed

+1394
-66
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 315 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
import logging
22
from typing import Any, Dict, Optional, Sequence, Tuple, Union
33

4-
import tensorrt as trt
54
import torch
65
from torch.fx.node import Argument, Node, Target
76
from torch_tensorrt.dynamo._SourceIR import SourceIR
87
from torch_tensorrt.dynamo.conversion import impl
9-
from torch_tensorrt.dynamo.conversion.converter_utils import (
10-
cast_int_int_div_trt_tensor,
11-
cast_trt_tensor,
12-
)
13-
from torch_tensorrt.fx.converters import acc_ops_converters
148
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
159

1610
from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
4842
)
4943

5044

51-
@dynamo_tensorrt_converter(torch.ops.aten.div.default) # type: ignore[misc]
52-
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) # type: ignore[misc]
53-
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) # type: ignore[misc]
54-
def aten_ops_div(
55-
network: TRTNetwork,
56-
target: Target,
57-
args: Tuple[Argument, ...],
58-
kwargs: Dict[str, Argument],
59-
name: str,
60-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
61-
kwargs_new = {
62-
"input": args[0],
63-
"other": args[1],
64-
}
65-
# If both are TRTTensor, both are cast to float32
66-
if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
67-
kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
68-
network,
69-
kwargs_new["input"],
70-
kwargs_new["other"],
71-
name,
72-
)
73-
# If one is TRTTensor, it is cast to float32
74-
elif isinstance(args[0], TRTTensor) and (
75-
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
76-
):
77-
kwargs_new["input"] = cast_trt_tensor(
78-
network, kwargs_new["input"], trt.float32, name, target
79-
)
80-
elif isinstance(args[1], TRTTensor) and (
81-
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
82-
):
83-
kwargs_new["other"] = cast_trt_tensor(
84-
network, kwargs_new["other"], trt.float32, name, target
85-
)
86-
rounding_mode = kwargs.get("rounding_mode")
87-
if rounding_mode is None:
88-
return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
89-
elif rounding_mode == "floor":
90-
return acc_ops_converters.acc_ops_floor_div(
91-
network, target, None, kwargs_new, name
92-
)
93-
elif rounding_mode == "trunc":
94-
return impl.elementwise.trunc_div(
95-
network, target, SourceIR.ATEN, name, args[0], args[1]
96-
)
97-
else:
98-
raise RuntimeError(
99-
f"Target {target} does not support rounding mode {rounding_mode}"
100-
)
101-
102-
10345
def embedding_param_validator(embedding_node: Node) -> bool:
10446
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
10547
sparse = args_bounds_check(embedding_node.args, 4)
@@ -1004,6 +946,321 @@ def aten_ops_isinf(
1004946
)
1005947

1006948

949+
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
def aten_ops_add(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.add`` (input + alpha * other) to TRT elementwise layers.

    Per the aten schema the optional ``alpha`` kwarg scales the second
    operand before the addition; when it is 1 the extra multiply is skipped.
    """
    lhs = args[0]
    rhs = args[1]
    scale = kwargs.get("alpha", 1)
    if scale != 1:
        # Pre-scale the right operand so the final layer is a plain add.
        rhs = impl.elementwise.mul(network, target, SourceIR.ATEN, name, rhs, scale)
    return impl.elementwise.add(network, target, SourceIR.ATEN, name, lhs, rhs)
979+
980+
981+
@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
def aten_ops_mul(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.mul`` to a TRT elementwise product."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.mul(network, target, SourceIR.ATEN, name, lhs, rhs)
998+
999+
1000+
@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
def aten_ops_max(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.maximum`` to a TRT elementwise max."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.max(network, target, SourceIR.ATEN, name, lhs, rhs)
1016+
1017+
1018+
@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
def aten_ops_min(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.minimum`` to a TRT elementwise min."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.min(network, target, SourceIR.ATEN, name, lhs, rhs)
1034+
1035+
1036+
@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
def aten_ops_sub(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.sub`` (input - alpha * other) to TRT elementwise layers.

    Mirrors ``aten_ops_add``: the optional ``alpha`` kwarg scales the second
    operand before the subtraction; a scale of 1 skips the extra multiply.
    """
    lhs = args[0]
    rhs = args[1]
    scale = kwargs.get("alpha", 1)
    if scale != 1:
        # Pre-scale the right operand so the final layer is a plain subtract.
        rhs = impl.elementwise.mul(network, target, SourceIR.ATEN, name, rhs, scale)
    return impl.elementwise.sub(network, target, SourceIR.ATEN, name, lhs, rhs)
1066+
1067+
1068+
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
def aten_ops_div(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.div``, dispatching on the optional ``rounding_mode``.

    ``None`` maps to true division, ``"floor"`` to floor division and
    ``"trunc"`` to truncated division; anything else is rejected.
    """
    rounding_mode = kwargs.get("rounding_mode")
    if rounding_mode is None:
        op = impl.elementwise.div
    elif rounding_mode == "floor":
        op = impl.elementwise.floor_divide
    elif rounding_mode == "trunc":
        op = impl.elementwise.trunc_div
    else:
        raise RuntimeError(
            f"Target {target} does not support rounding mode {rounding_mode}"
        )
    return op(network, target, SourceIR.ATEN, name, args[0], args[1])
1112+
1113+
1114+
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
def aten_ops_pow(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.pow`` (base, exponent) to a TRT elementwise power."""
    base, exponent = args[0], args[1]
    return impl.elementwise.pow(network, target, SourceIR.ATEN, name, base, exponent)
1132+
1133+
1134+
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
def aten_ops_floor_div(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.floor_divide`` to a TRT elementwise floor division."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.floor_divide(network, target, SourceIR.ATEN, name, lhs, rhs)
1151+
1152+
1153+
@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
def aten_ops_logical_and(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.logical_and`` to a TRT elementwise logical AND."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.logical_and(network, target, SourceIR.ATEN, name, lhs, rhs)
1169+
1170+
1171+
@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
def aten_ops_logical_or(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.logical_or`` to a TRT elementwise logical OR."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.logical_or(network, target, SourceIR.ATEN, name, lhs, rhs)
1187+
1188+
1189+
@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
def aten_ops_logical_xor(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.logical_xor`` to a TRT elementwise logical XOR."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.logical_xor(network, target, SourceIR.ATEN, name, lhs, rhs)
1205+
1206+
1207+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
def aten_ops_equal(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.eq`` to a TRT elementwise equality comparison."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.eq(network, target, SourceIR.ATEN, name, lhs, rhs)
1224+
1225+
1226+
@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
def aten_ops_greater(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.gt`` to a TRT elementwise greater-than comparison."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.gt(network, target, SourceIR.ATEN, name, lhs, rhs)
1243+
1244+
1245+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
def aten_ops_less(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.lt`` to a TRT elementwise less-than comparison."""
    lhs, rhs = args[0], args[1]
    return impl.elementwise.lt(network, target, SourceIR.ATEN, name, lhs, rhs)
1262+
1263+
10071264
def conv_param_validator(conv_node: Node) -> bool:
10081265
return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
10091266

0 commit comments

Comments
 (0)