@@ -815,6 +815,19 @@ def forward(self, input):
815
815
return x_l
816
816
817
817
818
class LinearMulAdd_v2(nn.Module):
    """Linear (no bias) scaled by a tensor and a scalar, plus a residual add.

    Computes ``mul_tensor * linear(input) * mul_scalar + input`` — a pattern
    used to exercise whether the JIT fuser does (or deliberately does not)
    fold this shape into a fused linear_mul_add kernel.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Bias-free linear so the graph contains a bare aten::linear.
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
        # Integer scalar tensor multiplier (value 1, so numerically a no-op).
        self.mul_tensor = torch.tensor(1)
        # Plain Python float multiplier.
        self.mul_scalar = 0.5

    def forward(self, input):
        residual = input
        scaled = self.mul_tensor * self.linear(input) * self.mul_scalar
        # Cast the residual to the product's dtype before the final add.
        return scaled + residual.to(scaled.dtype)
818
831
class LinearMul (nn .Module ):
819
832
def __init__ (self , in_features , num_layers , low_rank ):
820
833
super (LinearMul , self ).__init__ ()
@@ -841,6 +854,17 @@ def forward(self, input):
841
854
return x_l
842
855
843
856
857
class LinearMul_v2(nn.Module):
    """Linear (no bias) multiplied by a scalar and a tensor constant.

    Computes ``mul_scalar * linear(input) * mul_tensor`` — a pattern used to
    exercise whether the JIT fuser does (or deliberately does not) fold this
    shape into a fused linear_mul kernel.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Bias-free linear so the graph contains a bare aten::linear.
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
        # Integer scalar tensor multiplier (value 1, so numerically a no-op).
        self.mul_tensor = torch.tensor(1)
        # Plain Python float multiplier.
        self.mul_scalar = 0.5

    def forward(self, input):
        projected = self.linear(input)
        return self.mul_scalar * projected * self.mul_tensor
844
868
class Linear_Reshape_Relu (nn .Module ):
845
869
def __init__ (self , in_channels , out_channels , dest_shape , ** kwargs ):
846
870
super (Linear_Reshape_Relu , self ).__init__ ()
@@ -4536,6 +4560,25 @@ def test_output_linear_mul_add(self):
4536
4560
prec = 5e-2 ,
4537
4561
)
4538
4562
4563
+ def test_output_linear_mul_add_v2 (self ):
4564
+ m = LinearMulAdd_v2 (4 , 4 )
4565
+ x = torch .ones (2 , 4 )
4566
+ self ._test_output (
4567
+ m ,
4568
+ x ,
4569
+ kind_in_graph = "aten::linear" ,
4570
+ kind_not_in_graph = "ipex_prepack::linear_mul_add_run" ,
4571
+ )
4572
+ self ._test_mkl_fp32 (m , x , kind_in_graph = "ipex_prepack::mkl_sgemm_run" )
4573
+ self ._test_dnnl_fp32 (m , x , kind_in_graph = "ipex_prepack::linear_run" )
4574
+ self ._test_output_lowp (
4575
+ m ,
4576
+ x ,
4577
+ kind_in_graph = "ipex_prepack::linear_run" ,
4578
+ kind_not_in_graph = "ipex_prepack::linear_mul_add_run" ,
4579
+ prec = 5e-2 ,
4580
+ )
4581
+
4539
4582
def test_output_linear_mul (self ):
4540
4583
m = LinearMul (4 , 2 , 8 )
4541
4584
x = torch .ones (2 , 4 )
@@ -4549,6 +4592,25 @@ def test_output_linear_mul(self):
4549
4592
prec = 5e-2 ,
4550
4593
)
4551
4594
4595
+ def test_output_linear_mul_v2 (self ):
4596
+ m = LinearMul_v2 (4 , 4 )
4597
+ x = torch .ones (2 , 4 )
4598
+ self ._test_output (
4599
+ m ,
4600
+ x ,
4601
+ kind_in_graph = "aten::linear" ,
4602
+ kind_not_in_graph = "ipex_prepack::linear_mul_run" ,
4603
+ )
4604
+ self ._test_mkl_fp32 (m , x , kind_in_graph = "ipex_prepack::mkl_sgemm_run" )
4605
+ self ._test_dnnl_fp32 (m , x , kind_in_graph = "ipex_prepack::linear_run" )
4606
+ self ._test_output_lowp (
4607
+ m ,
4608
+ x ,
4609
+ kind_in_graph = "ipex_prepack::linear_run" ,
4610
+ kind_not_in_graph = "ipex_prepack::linear_mul_run" ,
4611
+ prec = 5e-2 ,
4612
+ )
4613
+
4552
4614
def test_output_linear_reshape_relu (self ):
4553
4615
self ._test_output (
4554
4616
Linear_Reshape_Relu (3 , 32 , (64 , 16 ), bias = True ),
0 commit comments