@@ -701,33 +701,14 @@ def forward(self, x, y, z):
         x = x + y + z
         return self.layernorm(x)

-class ConcatBnRelu2d(torch.nn.Module):
-    def __init__(self):
-        super(ConcatBnRelu2d, self).__init__()
-        self.bn = torch.nn.BatchNorm2d(96)
-        self.relu = torch.nn.ReLU()
-    def forward(self, x1, x2, x3):
-        x = torch.cat((x1, x2, x3), dim=1)
-        x = self.bn(x)
-        return self.relu(x)
-
-class ConcatBnRelu2d_v1(torch.nn.Module):
-    def __init__(self):
-        super(ConcatBnRelu2d_v1, self).__init__()
-        self.bn = torch.nn.BatchNorm2d(32)
-        self.relu = torch.nn.ReLU()
-    def forward(self, x1, x2, x3):
-        x = torch.cat((x1, x2, x3), dim=2)
-        x = self.bn(x)
-        return self.relu(x)
-
-class ConcatBnRelu3d(torch.nn.Module):
-    def __init__(self):
-        super(ConcatBnRelu3d, self).__init__()
-        self.bn = torch.nn.BatchNorm3d(96)
+class ConcatBnRelu(torch.nn.Module):
+    def __init__(self, dim, cat_dim, in_channels, **kwargs):
+        super(ConcatBnRelu, self).__init__()
+        self.bn = bn_module[dim](in_channels)
         self.relu = torch.nn.ReLU()
+        self.cat_dim = cat_dim
     def forward(self, x1, x2, x3):
-        x = torch.cat((x1, x2, x3), dim=1)
+        x = torch.cat((x1, x2, x3), dim=self.cat_dim)
         x = self.bn(x)
         return self.relu(x)

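Note: the rewritten class looks up its norm layer through a `bn_module` helper that this hunk references but does not define. A minimal sketch of the assumed definition, mapping the `dim` argument to the matching BatchNorm class (this reconstruction is ours, not part of the commit):

    import torch

    # Assumed helper: dim == 2 selects BatchNorm2d (4-D NCHW inputs),
    # dim == 3 selects BatchNorm3d (5-D NCDHW inputs). The diff only
    # shows the lookup bn_module[dim](in_channels), not this definition.
    bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
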
@@ -1010,114 +991,50 @@ def test_add_layernorm(self):
         self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))

     def test_concat_bn_relu(self):
-        a1 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
-        model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
-        model = ipex.optimize(model, dtype=torch.bfloat16, level='O0')
-        with torch.no_grad():
-            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-            jit_model = torch.jit.freeze(jit_model)
-            #warmup run
-            for _ in range(2):
-                jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
-
-        a1 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        model = ConcatBnRelu2d_v1().eval().to(memory_format=torch.channels_last)
-        model = ipex.optimize(model, dtype=torch.float32, level='O0')
-        with torch.no_grad():
-            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-            jit_model = torch.jit.freeze(jit_model)
-            #warmup run
-            for _ in range(2):
-                jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
-
-        model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
-        model = ipex.optimize(model, dtype=torch.float32, level='O0')
-        with torch.no_grad():
-            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-            jit_model = torch.jit.freeze(jit_model)
-            #warmup run
-            for _ in range(2):
-                jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
-
-        a1 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
-
-        a1 = torch.randn(1, 16, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 48, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+        batch_size = 3
+        image_size = 16
+        options = itertools.product([2, 3], [[32, 32, 32], [60, 60, 60], [17, 27, 32], [16, 32, 48]], [torch.float32, torch.bfloat16], ['O0', 'O1'], [True, False])
+        for dim, channels, dtype, level, use_channels_last in options:
+            input_size = [
+                [batch_size, channels[0], image_size, image_size],
+                [batch_size, channels[1], image_size, image_size],
+                [batch_size, channels[2], image_size, image_size]
+            ]
+            if dim == 3:
+                for i in range(3):
+                    input_size[i].append(image_size)
+            a1 = torch.randn(input_size[0], dtype=dtype)
+            a2 = torch.randn(input_size[1], dtype=dtype)
+            a3 = torch.randn(input_size[2], dtype=dtype)
+            a = [a1, a2, a3]

-        a1 = torch.randn(1, 17, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 47, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            in_channels = sum(channels)
+            model = ConcatBnRelu(dim, 1, in_channels).eval()

-        a1 = torch.randn(1, 32, 13, 24, dtype=torch.float)
-        a2 = torch.randn(1, 32, 13, 24, dtype=torch.float)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.float)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            if use_channels_last:
+                suggest_memory_format = torch.channels_last if dim == 2 else torch.channels_last_3d
+                for i in range(3):
+                    a[i] = a[i].to(memory_format=suggest_memory_format)
+                model = model.to(memory_format=suggest_memory_format)

-        a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        model = ConcatBnRelu3d().eval().to(memory_format=torch.channels_last_3d)
-        model = ipex.optimize(model, dtype=torch.float32, level='O0')
-        with torch.no_grad():
-            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-            jit_model = torch.jit.freeze(jit_model)
-            #warmup run
-            for _ in range(2):
-                jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            model = ipex.optimize(model, dtype=dtype, level=level)

-        a1 = torch.randn(1, 16, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a2 = torch.randn(1, 48, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a3 = torch.randn(1, 32, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            with torch.cpu.amp.autocast(enabled=True if dtype == torch.bfloat16 else False), torch.no_grad():
+                result = model(a[0], a[1], a[2])
+                trace_model = torch.jit.trace(model, (a[0], a[1], a[2])).eval()
+                trace_model = torch.jit.freeze(trace_model)

-        a1 = torch.randn(1, 17, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a2 = torch.randn(1, 47, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+                tresult = trace_model(a[0], a[1], a[2])
+                trace_graph = trace_model.graph_for(a[0], a[1], a[2])

-        a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
-        a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
-        a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            self.assertEqual(result, tresult)
+            self.assertEqual(tresult.dtype, dtype)
+            if use_channels_last:
+                self.assertTrue(tresult.is_contiguous(memory_format=suggest_memory_format))
+            if use_channels_last and a1.size(1) % 16 == 0 and a2.size(1) % 16 == 0 and a3.size(1) % 16 == 0:
+                self.assertTrue(any(n.kind() == "ipex::concat_bn_relu" for n in trace_graph.nodes()))
+            else:
+                self.assertTrue(all(n.kind() != "ipex::concat_bn_relu" for n in trace_graph.nodes()))

     def test_mha_scores_calculation(self):
         def _check_match_mha(trace_model, mat1, mat2, bias, node="ipex::mha_scores_calc"):
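For reference, the fusion condition the rewritten test encodes can be reproduced outside the unittest harness. A minimal sketch, assuming intel_extension_for_pytorch is importable as `ipex` and reusing the `ConcatBnRelu` module plus the assumed `bn_module` helper from above: with a channels_last memory format and every input's channel count divisible by 16, the frozen graph is expected to contain the fused `ipex::concat_bn_relu` node.

    import torch
    import intel_extension_for_pytorch as ipex

    # Assumed helper (not shown in the diff); see the note after the first hunk.
    bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

    class ConcatBnRelu(torch.nn.Module):
        def __init__(self, dim, cat_dim, in_channels):
            super(ConcatBnRelu, self).__init__()
            self.bn = bn_module[dim](in_channels)
            self.relu = torch.nn.ReLU()
            self.cat_dim = cat_dim

        def forward(self, x1, x2, x3):
            x = torch.cat((x1, x2, x3), dim=self.cat_dim)
            return self.relu(self.bn(x))

    # Three NCHW inputs with 32 channels each (divisible by 16), channels_last.
    inputs = [torch.randn(3, 32, 16, 16).to(memory_format=torch.channels_last)
              for _ in range(3)]
    model = ConcatBnRelu(2, 1, 96).eval().to(memory_format=torch.channels_last)
    model = ipex.optimize(model, dtype=torch.float32, level='O1')

    with torch.no_grad():
        traced = torch.jit.freeze(torch.jit.trace(model, tuple(inputs)).eval())
        traced(*inputs)  # warm-up run so the fusion pass is applied
        graph = traced.graph_for(*inputs)

    # Expected True under the conditions above, per the test's assertion.
    print(any(n.kind() == "ipex::concat_bn_relu" for n in graph.nodes()))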