@@ -412,9 +412,10 @@ def __init__(
         )
         self.quant_weight_per_channel()
 
-        # TODO - change bias dtyoe to arg.dtype
         self.bias = (
-            torch.nn.Parameter(torch.randn(self.oc), requires_grad=False)
+            torch.nn.Parameter(
+                torch.randn(self.oc).to(self.op_dtype), requires_grad=False
+            )
             if use_bias
             else None
        )
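This hunk resolves the TODO by casting the randomly initialized bias to `self.op_dtype`, so the bias matches the dtype the linear op runs in. A minimal standalone sketch of why the cast matters (the `op_dtype` value and the shapes are assumptions chosen to mirror the hunk, not the repo's `ManualDQLinear`): `torch.randn` defaults to float32, and `torch.nn.functional.linear` rejects a float32 bias when input and weight are float16.

```python
import torch

# Illustrative sketch, not the repo's ManualDQLinear: `op_dtype`, `oc`,
# and `ic` are assumptions chosen to mirror the hunk above.
op_dtype = torch.float16
oc, ic = 4, 8

x = torch.randn(2, ic, dtype=op_dtype)
w = torch.randn(oc, ic, dtype=op_dtype)
bias_fp32 = torch.randn(oc)               # torch.randn defaults to float32
bias_cast = torch.randn(oc).to(op_dtype)  # the pattern the diff adopts

# Mixing a float32 bias with float16 input/weight fails with a
# dtype-mismatch RuntimeError:
try:
    torch.nn.functional.linear(x, w, bias_fp32)
except RuntimeError as e:
    print(f"float32 bias rejected: {e}")

out = torch.nn.functional.linear(x, w, bias_cast)
print(out.dtype)  # torch.float16
```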
@@ -595,14 +596,14 @@ def fwd_weight_per_channel_group(self) -> torch.Tensor:
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Input
-        input = self.fwd_input_per_token(input)
+        input = self.fwd_input_per_token(input).to(self.op_dtype)
 
         # Weights
         w = (
             self.fwd_weight_per_channel_group()
             if self.w_scales.ndim == 2
             else self.fwd_weight_per_channel()
-        )
+        ).to(self.op_dtype)
 
         assert isinstance(w, torch.Tensor)
         return torch.nn.functional.linear(input, w, self.bias)
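The forward hunk applies the same dtype alignment to both operands: the per-token dequantized input and the dequantized weight come back in float32, so each is cast to `op_dtype` before the linear op. A hedged sketch of that round trip, where `fake_dequant` is a hypothetical stand-in for the module's `fwd_input_per_token` / `fwd_weight_per_channel(_group)` paths:

```python
import torch

# Hypothetical helper standing in for the module's dequantize paths;
# the dequantization math lands in float32 regardless of op dtype.
def fake_dequant(q: torch.Tensor, scale: float) -> torch.Tensor:
    return q.to(torch.float32) * scale

op_dtype = torch.float16

q_input = torch.randint(-128, 128, (2, 8), dtype=torch.int8)   # int8 activations
q_weight = torch.randint(-8, 8, (4, 8), dtype=torch.int8)      # 4-bit range weights

# Mirror the diff: cast both operands to op_dtype before the matmul.
x = fake_dequant(q_input, 0.02).to(op_dtype)
w = fake_dequant(q_weight, 0.10).to(op_dtype)

out = torch.nn.functional.linear(x, w)
assert out.dtype == op_dtype
```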
@@ -734,6 +735,38 @@ def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
                     use_bias=use_bias,
                 )
 
+    def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
+        M_sizes = [1, 2, 17, 31]
+        K_sizes = [8, 32, 64, 128]
+        bl_sizes = [8, 16, 16, 32]
+        N_sizes = [2, 17, 92, 128]
+
+        for use_bias in [True, False]:
+            for i, _ in enumerate(M_sizes):
+                M = int(M_sizes[i])
+                K = int(K_sizes[i])
+                N = int(N_sizes[i])
+                bl = int(bl_sizes[i])
+                mod = self.ManualDQLinear(
+                    input_channels=K,
+                    output_channels=N,
+                    weight_n_bit=4,
+                    dtype=torch.float16,
+                    group_size=bl,
+                    force_groupwise_quant=True,
+                    use_bias=use_bias,
+                )
+
+                inputs = (torch.randn(1, M, K, dtype=torch.float16),)
+                self._test_manual_dq_linear(
+                    mod,
+                    inputs,
+                    weight_groupwise=True,
+                    use_bias=use_bias,
+                    atol=0.1,
+                    rtol=0.1,
+                )
+
     def _test_linear(
         self,
         make_module,
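The new test mirrors the fp32 qd8 per-token, weight-per-channel-group variant but runs the whole path in `torch.float16`, with tolerances loosened to `atol=0.1, rtol=0.1`: float16 carries only about 10 mantissa bits, and 4-bit groupwise dequantization in half precision accumulates additional error. A rough sanity check of that error scale (an assumption for illustration, not part of the test suite), comparing an fp16 linear against an fp32 reference at one of the test's shapes:

```python
import torch

# Hypothetical check of fp16 matmul error alone at M=17, K=64, N=92;
# 4-bit weight quantization in the real test adds error on top of this,
# which is what the 0.1 atol/rtol headroom accounts for.
x = torch.randn(17, 64)
w = torch.randn(92, 64)

ref = torch.nn.functional.linear(x, w)
half = torch.nn.functional.linear(x.half(), w.half()).float()

print((ref - half).abs().max())  # worst-case elementwise error of fp16 alone
```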