@@ -1,16 +1,10 @@
 import logging
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
-import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
-from torch_tensorrt.dynamo.conversion.converter_utils import (
-    cast_int_int_div_trt_tensor,
-    cast_trt_tensor,
-)
-from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.div.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)  # type: ignore[misc]
-def aten_ops_div(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "other": args[1],
-    }
-    # If both are TRTTensor, both are cast to float32
-    if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
-        kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
-            network,
-            kwargs_new["input"],
-            kwargs_new["other"],
-            name,
-        )
-    # If one is TRTTensor, it is cast to float32
-    elif isinstance(args[0], TRTTensor) and (
-        kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
-    ):
-        kwargs_new["input"] = cast_trt_tensor(
-            network, kwargs_new["input"], trt.float32, name, target
-        )
-    elif isinstance(args[1], TRTTensor) and (
-        kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
-    ):
-        kwargs_new["other"] = cast_trt_tensor(
-            network, kwargs_new["other"], trt.float32, name, target
-        )
-    rounding_mode = kwargs.get("rounding_mode")
-    if rounding_mode is None:
-        return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
-    elif rounding_mode == "floor":
-        return acc_ops_converters.acc_ops_floor_div(
-            network, target, None, kwargs_new, name
-        )
-    elif rounding_mode == "trunc":
-        return impl.elementwise.trunc_div(
-            network, target, SourceIR.ATEN, name, args[0], args[1]
-        )
-    else:
-        raise RuntimeError(
-            f"Target {target} does not support rounding mode {rounding_mode}"
-        )
-
-
 def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)
@@ -1004,6 +946,321 @@ def aten_ops_isinf(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
+def aten_ops_add(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    other = args[1]
+    alpha = kwargs.get("alpha", 1)
+
+    if alpha != 1:
+        other = impl.elementwise.mul(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            other,
+            alpha,
+        )
+
+    return impl.elementwise.add(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        other,
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
+def aten_ops_mul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.mul(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
+def aten_ops_max(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.max(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
+def aten_ops_min(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.min(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
+def aten_ops_sub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    other = args[1]
+    alpha = kwargs.get("alpha", 1)
+
+    if alpha != 1:
+        other = impl.elementwise.mul(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            other,
+            alpha,
+        )
+
+    return impl.elementwise.sub(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        other,
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
+def aten_ops_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    rounding_mode = kwargs.get("rounding_mode")
+
+    if rounding_mode is None:
+        return impl.elementwise.div(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    elif rounding_mode == "floor":
+        return impl.elementwise.floor_divide(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    elif rounding_mode == "trunc":
+        return impl.elementwise.trunc_div(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    else:
+        raise RuntimeError(
+            f"Target {target} does not support rounding mode {rounding_mode}"
+        )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
+def aten_ops_pow(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.pow(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
+def aten_ops_floor_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.floor_divide(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
+def aten_ops_logical_and(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_and(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_or.default)
+def aten_ops_logical_or(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_or(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.logical_xor.default)
+def aten_ops_logical_xor(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.logical_xor(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
+def aten_ops_equal(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.eq(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
+def aten_ops_greater(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.gt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
+def aten_ops_less(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.lt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 def conv_param_validator(conv_node: Node) -> bool:
     return (not conv_node.args[6]) and (conv_node.args[7] in ([0], [0, 0], [0, 0, 0]))
 
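For orientation only, here is a minimal sketch of how the newly registered elementwise converters could be exercised end to end. It is not part of the diff: the module, input shapes, and the exact torch_tensorrt.compile keyword names are illustrative assumptions and may differ between releases.

# Illustrative sketch (not part of this change): a small model whose ops lower
# to aten.add.Tensor, aten.mul.Tensor, and aten.div.Tensor_mode, compiled
# through the Torch-TensorRT dynamo frontend. Assumes a CUDA GPU and a
# matching TensorRT install; compile() argument names may vary by version.
import torch
import torch_tensorrt


class Elementwise(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        z = torch.add(x, y, alpha=2)  # aten.add.Tensor with an alpha kwarg
        z = z * y                     # aten.mul.Tensor
        return torch.div(z, y, rounding_mode="floor")  # aten.div.Tensor_mode


x = torch.randn(2, 3).cuda()
y = torch.rand(2, 3).cuda() + 1.0  # keep the divisor away from zero
trt_module = torch_tensorrt.compile(
    Elementwise().eval().cuda(), ir="dynamo", inputs=[x, y]
)
print(trt_module(x, y))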