@@ -1,16 +1,10 @@
 import logging
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
-import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
-from torch_tensorrt.dynamo.conversion.converter_utils import (
-    cast_int_int_div_trt_tensor,
-    cast_trt_tensor,
-)
-from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.div.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)  # type: ignore[misc]
-def aten_ops_div(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "other": args[1],
-    }
-    # If both are TRTTensor, both are cast to float32
-    if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
-        kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
-            network,
-            kwargs_new["input"],
-            kwargs_new["other"],
-            name,
-        )
-    # If one is TRTTensor, it is cast to float32
-    elif isinstance(args[0], TRTTensor) and (
-        kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
-    ):
-        kwargs_new["input"] = cast_trt_tensor(
-            network, kwargs_new["input"], trt.float32, name, target
-        )
-    elif isinstance(args[1], TRTTensor) and (
-        kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
-    ):
-        kwargs_new["other"] = cast_trt_tensor(
-            network, kwargs_new["other"], trt.float32, name, target
-        )
-    rounding_mode = kwargs.get("rounding_mode")
-    if rounding_mode is None:
-        return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
-    elif rounding_mode == "floor":
-        return acc_ops_converters.acc_ops_floor_div(
-            network, target, None, kwargs_new, name
-        )
-    elif rounding_mode == "trunc":
-        return impl.elementwise.trunc_div(
-            network, target, SourceIR.ATEN, name, args[0], args[1]
-        )
-    else:
-        raise RuntimeError(
-            f"Target {target} does not support rounding mode {rounding_mode}"
-        )
-
-
 def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)
@@ -982,24 +924,39 @@ def aten_ops_isinf(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
 def aten_ops_add(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    other = args[1]
+    alpha = kwargs.get("alpha", 1)
+
+    if alpha != 1:
+        other = impl.elementwise.mul(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            other,
+            alpha,
+        )
+
     return impl.elementwise.add(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
+        other,
     )
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
 def aten_ops_mul(
     network: TRTNetwork,
     target: Target,
@@ -1054,43 +1011,86 @@ def aten_ops_min(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
 def aten_ops_sub(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    other = args[1]
+    alpha = kwargs.get("alpha", 1)
+
+    if alpha != 1:
+        other = impl.elementwise.mul(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            other,
+            alpha,
+        )
+
     return impl.elementwise.sub(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
+        other,
     )
 
 
-# TODO: keep this or line 54...?
-# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
-# def aten_ops_div(
-#     network: TRTNetwork,
-#     target: Target,
-#     args: Tuple[Argument, ...],
-#     kwargs: Dict[str, Argument],
-#     name: str,
-# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-#     return impl.elementwise.div(
-#         network,
-#         target,
-#         SourceIR.ATEN,
-#         name,
-#         args[0],
-#         args[1],
-#     )
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
+def aten_ops_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    rounding_mode = kwargs.get("rounding_mode")
+
+    if rounding_mode is None:
+        return impl.elementwise.div(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    elif rounding_mode == "floor":
+        return impl.elementwise.floor_divide(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    elif rounding_mode == "trunc":
+        return impl.elementwise.trunc_div(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    else:
+        raise RuntimeError(
+            f"Target {target} does not support rounding mode {rounding_mode}"
+        )
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
 def aten_ops_pow(
     network: TRTNetwork,
     target: Target,
@@ -1109,6 +1109,7 @@ def aten_ops_pow(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
 def aten_ops_floor_div(
     network: TRTNetwork,
     target: Target,
@@ -1181,6 +1182,7 @@ def aten_ops_logical_xor(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
 def aten_ops_equal(
     network: TRTNetwork,
     target: Target,
@@ -1199,6 +1201,7 @@ def aten_ops_equal(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
 def aten_ops_greater(
     network: TRTNetwork,
     target: Target,
@@ -1217,6 +1220,7 @@ def aten_ops_greater(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
 def aten_ops_less(
     network: TRTNetwork,
     target: Target,