@@ -761,6 +761,18 @@ def forward(self, x):
         return torch.add(self.linear(x), self.linear1(x1))
 
 
+class LinearAdd2(nn.Module):
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(LinearAdd2, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.linear = nn.Linear(in_channels, out_channels, **kwargs)
+
+    def forward(self, x):
+        y = x.clone().unsqueeze(0).permute(2, 1, 0, 3).squeeze(0)
+        return self.linear(x) + y
+
+
 class LinearAddRelu(nn.Module):
     def __init__(self, in_channels, mid_channels, out_channels, inplace, **kwargs):
         super(LinearAddRelu, self).__init__()
@@ -4480,6 +4492,11 @@ def test_output_linear_add(self):
             torch.rand(32, 3),
             kind_in_graph="ipex_prepack::linear_add_run",
         )
+        self._test_dnnl_fp32(
+            LinearAdd2(3, 3, bias=False),
+            torch.rand(3, 1, 3),
+            kind_in_graph="ipex_prepack::linear_add_run",
+        )
         self._test_output_lowp(
             LinearAdd(3, 32, bias=True),
             torch.rand(32, 3),
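Note (not part of the patch): the second operand of the add in `LinearAdd2.forward` is built from a clone/unsqueeze/permute/squeeze chain. A minimal shape trace, assuming stock PyTorch, shows that for the `(3, 1, 3)` input used by the new test case the permute only swaps two size-1 dimensions, so `y` ends up with the same shape and values as `x` while being produced through a separate clone-and-view chain rather than being `x` itself:

```python
import torch

# Reproduce the tensor manipulation from LinearAdd2.forward on the
# (3, 1, 3) input used by the new test case.
x = torch.rand(3, 1, 3)
y = x.clone()                 # copy, so y does not alias x's storage
y = y.unsqueeze(0)            # (3, 1, 3) -> (1, 3, 1, 3)
y = y.permute(2, 1, 0, 3)     # swaps the two size-1 dims: still (1, 3, 1, 3)
y = y.squeeze(0)              # back to (3, 1, 3)

print(y.shape)                # torch.Size([3, 1, 3])
print(torch.equal(x, y))      # True: same shape and values as x
```

The added `_test_dnnl_fp32` call then checks that `ipex_prepack::linear_add_run` still appears in the generated graph for this 3D, bias-free linear + add pattern.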