2
2
import torch_tensorrt
3
3
from parameterized import parameterized
4
4
from torch .testing ._internal .common_utils import TestCase , run_tests
5
- from parameterized import parameterized
6
5
7
6
from ..testing_utilities import DECIMALS_OF_AGREEMENT , lower_graph_testing
8
7
@@ -963,37 +962,60 @@ def forward(self, input):
963
962
f"The optimized model results shape and torch model results shape should be equal in empty_stride" ,
964
963
)
965
964
966
-
967
- class TestScatterAdd (TestCase ):
968
965
@parameterized .expand (
969
966
[
970
967
(
971
968
"scatter_add_zero_dim_indexOne_constant" ,
972
969
0 ,
973
- torch .tensor ([[0 , 1 , 2 , 0 ]]),
974
- torch .tensor ([[1 , 2 , 3 , 4 ]], dtype = torch .int32 ),
970
+ torch .tensor ([[0 , 1 , 2 , 0 ]]).cuda (),
971
+ torch .tensor ([[1 , 2 , 3 , 4 ]], dtype = torch .int32 ).cuda (),
972
+ {torch .ops .aten .add .Tensor },
975
973
),
976
974
(
977
975
"scatter_add_zero_dim_indexTwo_constant" ,
978
976
0 ,
979
- torch .tensor ([[0 , 1 , 2 , 0 ], [1 , 2 , 1 , 1 ]]),
980
- torch .tensor ([[1 , 2 , 3 , 4 ], [5 , 6 , 7 , 8 ]], dtype = torch .int32 ),
977
+ torch .tensor ([[0 , 1 , 2 , 0 ], [1 , 2 , 1 , 1 ]]).cuda (),
978
+ torch .tensor ([[1 , 2 , 3 , 4 ], [5 , 6 , 7 , 8 ]], dtype = torch .int32 ).cuda (),
979
+ {torch .ops .aten .add .Tensor , torch .ops .aten .scatter .src },
981
980
),
982
981
(
983
982
"scatter_add_one_dim_indexOne_constant" ,
984
983
1 ,
985
- torch .tensor ([[0 , 1 , 2 , 0 ]]),
986
- torch .tensor ([[1 , 2 , 3 , 1 ]], dtype = torch .int32 ),
984
+ torch .tensor ([[0 , 1 , 2 , 0 ]]).cuda (),
985
+ torch .tensor ([[1 , 2 , 3 , 1 ]], dtype = torch .int32 ).cuda (),
986
+ {
987
+ torch .ops .aten .add .Tensor ,
988
+ torch .ops .aten .scatter .src ,
989
+ torch .ops .aten .full_like .default ,
990
+ },
991
+ ),
992
+ (
993
+ "scatter_add_one_dim_indexTwo_constant" ,
994
+ 1 ,
995
+ torch .tensor ([[0 , 1 , 2 , 0 ], [1 , 2 , 1 , 1 ]]).cuda (),
996
+ torch .tensor ([[1 , 2 , 3 , 1 ], [5 , 6 , 5 , 5 ]], dtype = torch .int32 ).cuda (),
997
+ {
998
+ torch .ops .aten .add .Tensor ,
999
+ torch .ops .aten .scatter .src ,
1000
+ torch .ops .aten .full_like .default ,
1001
+ },
987
1002
),
988
1003
(
989
- "scatter_add_one_dim_indexTwo_costant " ,
1004
+ "scatter_add_one_dim_indexThree_constant" ,
990
1005
1 ,
991
- torch .tensor ([[0 , 1 , 2 , 0 ], [1 , 2 , 1 , 1 ]]),
992
- torch .tensor ([[1 , 2 , 3 , 1 ], [5 , 6 , 5 , 5 ]], dtype = torch .int32 ),
1006
+ torch .tensor ([[0 , 1 , 2 , 0 ], [1 , 2 , 1 , 1 ], [3 , 2 , 1 , 2 ]]).cuda (),
1007
+ torch .tensor (
1008
+ [[1 , 2 , 3 , 1 ], [5 , 6 , 5 , 5 ], [2 , 4 , 3 , 2 ]], dtype = torch .int32
1009
+ ).cuda (),
1010
+ {
1011
+ torch .ops .aten .add .Tensor ,
1012
+ torch .ops .aten .scatter .src ,
1013
+ torch .ops .aten .full_like .default ,
1014
+ },
993
1015
),
994
1016
]
995
1017
)
996
- def test_scatter_add (self , _ , dim , index , src ):
1018
+ def test_scatter_add (self , _ , dim , index , src , expected_ops_param ):
997
1019
class TestModule (torch .nn .Module ):
998
1020
def __init__ (self ):
999
1021
super ().__init__ ()
@@ -1002,14 +1024,19 @@ def forward(self, input):
1002
1024
return torch .ops .aten .scatter_add .default (input , dim , index , src )
1003
1025
1004
1026
# Operations expected to be included in the traced graph after decompositions
1005
- expected_ops = {torch .ops .aten .scatter .src }
1027
+ expected_ops = expected_ops_param
1028
+ unexpected_ops = {torch .ops .aten .scatter_add .default }
1006
1029
1007
- input = torch .zeros (3 , 5 , dtype = torch .int32 )
1030
+ input = torch .zeros (3 , 5 , dtype = torch .int32 ). cuda ()
1008
1031
inputs = [input ]
1009
1032
1010
1033
fx_graph = torch .fx .symbolic_trace (TestModule ())
1011
- _ , expected_ops_unseen = lower_graph_testing (
1012
- fx_graph , inputs , expected_ops = expected_ops , min_block_size = 2
1034
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
1035
+ fx_graph ,
1036
+ inputs ,
1037
+ expected_ops = expected_ops ,
1038
+ unexpected_ops = unexpected_ops ,
1039
+ min_block_size = 2 ,
1013
1040
)
1014
1041
1015
1042
self .assertEqual (
@@ -1018,6 +1045,36 @@ def forward(self, input):
1018
1045
f"The following expected ops were not encountered: { expected_ops_unseen } " ,
1019
1046
)
1020
1047
1048
+ self .assertEqual (
1049
+ len (unexpected_ops_seen ),
1050
+ 0 ,
1051
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
1052
+ )
1053
+
1054
+ torch ._dynamo .reset ()
1055
+
1056
+ # Validate that the results between Torch and Torch-TRT are similar
1057
+ optimized_model = torch_tensorrt .compile (
1058
+ fx_graph ,
1059
+ "torch_compile" ,
1060
+ inputs ,
1061
+ min_block_size = 1 ,
1062
+ truncate_double = True ,
1063
+ pass_through_build_failures = True ,
1064
+ )
1065
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
1066
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
1067
+
1068
+ max_diff = float (
1069
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
1070
+ )
1071
+ self .assertAlmostEqual (
1072
+ max_diff ,
1073
+ 0 ,
1074
+ DECIMALS_OF_AGREEMENT ,
1075
+ f"Scatter_add TRT outputs don't match with the original model." ,
1076
+ )
1077
+
1021
1078
1022
1079
if __name__ == "__main__" :
1023
1080
run_tests ()
0 commit comments