import logging
from typing import Any, Dict, Optional, Sequence, Tuple, Union

-import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
-from torch_tensorrt.dynamo.conversion.converter_utils import (
-    cast_int_int_div_trt_tensor,
-    cast_trt_tensor,
-)
-from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .converter_registry import dynamo_tensorrt_converter
@@ -48,58 +42,6 @@ def aten_ops_batch_norm(
    )


-@dynamo_tensorrt_converter(torch.ops.aten.div.default)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)  # type: ignore[misc]
-@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)  # type: ignore[misc]
-def aten_ops_div(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "other": args[1],
-    }
-    # If both are TRTTensor, both are cast to float32
-    if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
-        kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
-            network,
-            kwargs_new["input"],
-            kwargs_new["other"],
-            name,
-        )
-    # If one is TRTTensor, it is cast to float32
-    elif isinstance(args[0], TRTTensor) and (
-        kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
-    ):
-        kwargs_new["input"] = cast_trt_tensor(
-            network, kwargs_new["input"], trt.float32, name, target
-        )
-    elif isinstance(args[1], TRTTensor) and (
-        kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
-    ):
-        kwargs_new["other"] = cast_trt_tensor(
-            network, kwargs_new["other"], trt.float32, name, target
-        )
-    rounding_mode = kwargs.get("rounding_mode")
-    if rounding_mode is None:
-        return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
-    elif rounding_mode == "floor":
-        return acc_ops_converters.acc_ops_floor_div(
-            network, target, None, kwargs_new, name
-        )
-    elif rounding_mode == "trunc":
-        return impl.elementwise.trunc_div(
-            network, target, SourceIR.ATEN, name, args[0], args[1]
-        )
-    else:
-        raise RuntimeError(
-            f"Target {target} does not support rounding mode {rounding_mode}"
-        )
-
-
def embedding_param_validator(embedding_node: Node) -> bool:
    scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
    sparse = args_bounds_check(embedding_node.args, 4)
@@ -846,24 +788,39 @@ def aten_ops_isinf(


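+# Registering the .Scalar overload alongside .Tensor lets this converter
+# accept a python-number `other` as well as a Tensor operand.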
@dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
def aten_ops_add(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
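+    # aten.add computes input + alpha * other, so a non-default alpha is
+    # folded into `other` before the elementwise add.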
+    other = args[1]
+    alpha = kwargs.get("alpha", 1)
+
+    if alpha != 1:
+        other = impl.elementwise.mul(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            other,
+            alpha,
+        )
+
    return impl.elementwise.add(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
-        args[1],
+        other,
    )


@dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar)
def aten_ops_mul(
    network: TRTNetwork,
    target: Target,
@@ -918,43 +875,86 @@ def aten_ops_min(


@dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar)
def aten_ops_sub(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
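+    # aten.sub computes input - alpha * other; scale `other` by alpha
+    # first, mirroring aten_ops_add above.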
+    other = args[1]
+    alpha = kwargs.get("alpha", 1)
+
+    if alpha != 1:
+        other = impl.elementwise.mul(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            other,
+            alpha,
+        )
+
    return impl.elementwise.sub(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
-        args[1],
+        other,
    )


-# TODO: keep this or line 54...?
-# @dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
-# def aten_ops_div(
-#     network: TRTNetwork,
-#     target: Target,
-#     args: Tuple[Argument, ...],
-#     kwargs: Dict[str, Argument],
-#     name: str,
-# ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-#     return impl.elementwise.div(
-#         network,
-#         target,
-#         SourceIR.ATEN,
-#         name,
-#         args[0],
-#         args[1],
-#     )
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar_mode)
+def aten_ops_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
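+    # torch.div rounding_mode semantics: None is true division, "floor"
+    # rounds toward -inf, and "trunc" rounds toward zero.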
+    rounding_mode = kwargs.get("rounding_mode")
+
+    if rounding_mode is None:
+        return impl.elementwise.div(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    elif rounding_mode == "floor":
+        return impl.elementwise.floor_divide(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    elif rounding_mode == "trunc":
+        return impl.elementwise.trunc_div(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+        )
+    else:
+        raise RuntimeError(
+            f"Target {target} does not support rounding mode {rounding_mode}"
+        )


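+# Note: aten.pow.Scalar is (Scalar base, Tensor exponent) while
+# aten.pow.Tensor_Scalar is (Tensor base, Scalar exponent); both overloads
+# route through the same elementwise pow implementation.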
@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
def aten_ops_pow(
    network: TRTNetwork,
    target: Target,
@@ -973,6 +973,7 @@ def aten_ops_pow(


@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
+@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
def aten_ops_floor_div(
    network: TRTNetwork,
    target: Target,
@@ -1045,6 +1046,7 @@ def aten_ops_logical_xor(


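+# As with the arithmetic converters above, the .Scalar overloads of the
+# comparison ops accept a python-number right-hand side.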
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
def aten_ops_equal(
    network: TRTNetwork,
    target: Target,
@@ -1063,6 +1065,7 @@ def aten_ops_equal(


@dynamo_tensorrt_converter(torch.ops.aten.gt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.gt.Scalar)
def aten_ops_greater(
    network: TRTNetwork,
    target: Target,
@@ -1081,6 +1084,7 @@ def aten_ops_greater(


@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
def aten_ops_less(
    network: TRTNetwork,
    target: Target,