Commit bf020cd

feat: support many elementwise dynamo converters
add output_dtypes in test add util func and fix bugs
1 parent fe0d8e0 commit bf020cd

18 files changed: +1014 −5 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 253 additions & 0 deletions

@@ -981,6 +981,259 @@ def aten_ops_isinf(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
+def aten_ops_add(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.add(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
+def aten_ops_mul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.mul(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
+def aten_ops_max(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.max(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
+def aten_ops_min(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.min(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
+def aten_ops_sub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.sub(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+# TODO: keep this or line 54...?
+# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
+# def aten_ops_div(
+#     network: TRTNetwork,
+#     target: Target,
+#     args: Tuple[Argument, ...],
+#     kwargs: Dict[str, Argument],
+#     name: str,
+# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+#     return impl.elementwise.div(
+#         network,
+#         target,
+#         SourceIR.ATEN,
+#         name,
+#         args[0],
+#         args[1],
+#     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
+def aten_ops_pow(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.pow(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
+def aten_ops_floor_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.floor_divide(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
+def aten_ops_logical_and(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_and(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
+def aten_ops_logical_or(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_or(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
+def aten_ops_logical_xor(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_xor(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
+def aten_ops_equal(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.eq(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
+def aten_ops_greater(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.gt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
+def aten_ops_less(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.lt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 def conv_param_validator(conv_node: Node) -> bool:
     return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
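The thirteen converters registered above (plus the commented-out div) share the same body shape and differ only in the registered ATen overload and the impl.elementwise function they forward to. As an aside, a small factory could collapse that boilerplate; the sketch below is illustrative only (make_elementwise_converter is a hypothetical helper, not part of this commit) and assumes dynamo_tensorrt_converter can be applied to a function defined inside another function.

# Hypothetical refactor sketch -- not part of this commit. Assumes the
# dynamo_tensorrt_converter decorator also works when applied programmatically,
# and reuses TRTNetwork, Target, Argument, TRTTensor, SourceIR from this module.
def make_elementwise_converter(aten_target, elementwise_fn):
    @dynamo_tensorrt_converter(aten_target)
    def converter(
        network: TRTNetwork,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        # Forward the two positional operands to the shared elementwise impl.
        return elementwise_fn(network, target, SourceIR.ATEN, name, args[0], args[1])

    return converter

# Example registration, equivalent to aten_ops_add above:
# make_elementwise_converter(torch.ops.aten.add.Tensor, impl.elementwise.add)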

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 14 additions & 0 deletions

@@ -188,3 +188,17 @@ def extend_attr_to_tuple(
     if isinstance(val, list):
         val = tuple(val)
     return val
+
+
+def trt_cast_int_to_float(network: TRTNetwork, name: str, tensor: TRTTensor):
+    if tensor.dtype == trt.int8 or tensor.dtype == trt.int32:
+        return cast_trt_tensor(network, tensor, trt.float32, name)
+
+    return tensor
+
+
+def trt_cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor):
+    if tensor.dtype != trt.bool:
+        return cast_trt_tensor(network, tensor, trt.bool, name)
+
+    return tensor
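These helpers mirror TensorRT's elementwise type constraints: int8/int32 inputs get promoted to float32 for arithmetic layers, and non-bool inputs get cast to bool for logical layers. A minimal usage sketch, assuming convert_binary_elementwise from impl/elementwise/base.py takes (network, target, source_ir, name, op, lhs, rhs) in that order (the argument order is an assumption, not confirmed by this diff):

# Hedged sketch of how the bool-cast helper might be used by a logical-op
# implementation; imports of trt, trt_cast_int_or_float_to_bool, and
# convert_binary_elementwise from the surrounding modules are assumed.
def logical_and_sketch(network, target, source_ir, name, lhs, rhs):
    # Normalize both operands to bool before building the AND layer.
    lhs = trt_cast_int_or_float_to_bool(network, f"{name}_lhs_bool", lhs)
    rhs = trt_cast_int_or_float_to_bool(network, f"{name}_rhs_bool", rhs)
    return convert_binary_elementwise(
        network, target, source_ir, name, trt.ElementWiseOperation.AND, lhs, rhs
    )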

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 23 additions & 4 deletions

@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Optional, Union

+import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -24,12 +25,30 @@ def get_python_op_from_trt_elementwise_op(
         return operator.add
     elif trt_op == trt.ElementWiseOperation.PROD:
         return operator.mul
+    elif trt_op == trt.ElementWiseOperation.MAX:
+        return lambda a, b: max(a, b)
+    elif trt_op == trt.ElementWiseOperation.MIN:
+        return lambda a, b: min(a, b)
     elif trt_op == trt.ElementWiseOperation.SUB:
         return operator.sub
     elif trt_op == trt.ElementWiseOperation.DIV:
         return operator.truediv
+    elif trt_op == trt.ElementWiseOperation.POW:
+        return operator.pow
     elif trt_op == trt.ElementWiseOperation.FLOOR_DIV:
         return operator.floordiv
+    elif trt_op == trt.ElementWiseOperation.AND:
+        return lambda a, b: a and b
+    elif trt_op == trt.ElementWiseOperation.OR:
+        return lambda a, b: a or b
+    elif trt_op == trt.ElementWiseOperation.XOR:
+        return lambda a, b: (a or b) and not (a and b)
+    elif trt_op == trt.ElementWiseOperation.EQUAL:
+        return operator.eq
+    elif trt_op == trt.ElementWiseOperation.GREATER:
+        return operator.gt
+    elif trt_op == trt.ElementWiseOperation.LESS:
+        return operator.lt
     else:
         raise RuntimeError(f"{trt_op} is not supported yet!")
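get_python_op_from_trt_elementwise_op supplies a plain-Python fallback for each TensorRT op so the converter can compute a result eagerly when no TRT tensor is involved (the two-scalar call pattern is assumed here). A quick self-contained spot-check of the new mappings, runnable without TensorRT:

import operator

# The lambdas duplicate the new mappings added in the hunk above.
xor = lambda a, b: (a or b) and not (a and b)  # XOR mapping
assert xor(True, False) and not xor(True, True)
assert (lambda a, b: max(a, b))(2, 5) == 5     # MAX mapping
assert (lambda a, b: min(a, b))(2, 5) == 2     # MIN mapping
assert operator.pow(2, 3) == 8                 # POW mapping
assert operator.eq(1, 1) and operator.gt(2, 1) and operator.lt(1, 2)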

@@ -75,10 +94,10 @@ def convert_binary_elementwise(
     is_rhs_trt_tensor = False

     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
+        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
+        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
         is_rhs_trt_tensor = True

     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:

@@ -103,9 +122,9 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array([rhs_val], dtype=lhs_dtype)
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
+        lhs_val = np.array([lhs_val], dtype=rhs_dtype)

     # When lhs is scalar, and rhs has shape [1,], then currently the assert
     # will fail because lhs shape has fewer dimensions than rhs shape. This
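Switching the scalar wrapper from torch.tensor to np.array keeps the constant in the same framework as the dtype lookup above (Frameworks.NUMPY instead of Frameworks.TORCH), so the Python scalar inherits the TRT tensor's dtype rather than NumPy's default float64/int64 promotion. A minimal illustration, using np.float32 as a stand-in for the converted lhs dtype:

import numpy as np

# Stand-in for unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY).
lhs_dtype = np.float32
rhs_val = np.array([2], dtype=lhs_dtype)  # scalar rhs wrapped as in the diff
assert rhs_val.dtype == np.float32        # no silent int64/float64 promotion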
