@@ -791,7 +791,14 @@ def _calc_padded_size_linear_int4(k, groupsize=1, inner_k_tiles=1):
     return find_multiple(k, groupsize, inner_k_tiles * 16)


-def replace_linear_8da4w(module, group_size, inner_k_tiles, padding_allowed):
+def replace_linear_8da4w(
+    module,
+    group_size,
+    inner_k_tiles,
+    padding_allowed,
+    activation_precision,
+    weight_precision,
+):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
             if (
@@ -807,20 +814,37 @@ def replace_linear_8da4w(module, group_size, inner_k_tiles, padding_allowed):
                     bias=False,
                     group_size=group_size,
                     inner_k_tiles=inner_k_tiles,
+                    activation_precision=activation_precision,
+                    weight_precision=weight_precision,
                 ),
             )
         else:
-            replace_linear_8da4w(child, group_size, inner_k_tiles, padding_allowed)
+            replace_linear_8da4w(
+                child,
+                group_size,
+                inner_k_tiles,
+                padding_allowed,
+                activation_precision,
+                weight_precision,
+            )


 class Int8DynActInt4WeightQuantHandler:
-    def __init__(self, mod, group_size=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(
+        self,
+        mod,
+        group_size=128,
+        inner_k_tiles=8,
+        padding_allowed=True,
+        activation_precision=torch.float16,
+        weight_precision=torch.float16,
+    ):
         self.mod = mod
         self.group_size = group_size
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
-        # TODO: make this an argument
-        self.precision = torch.float16
+        self.activation_precision = activation_precision
+        self.weight_precision = weight_precision
         assert group_size in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]

@@ -861,7 +885,9 @@ def create_quantized_state_dict(self):
                     weight_int4pack,
                     scales_and_zeros,
                 ) = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(self.precision), self.group_size, self.inner_k_tiles
+                    weight.to(self.weight_precision),
+                    self.group_size,
+                    self.inner_k_tiles,
                 )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
@@ -870,7 +896,12 @@ def create_quantized_state_dict(self):

     def convert_for_runtime(self):
         replace_linear_8da4w(
-            self.mod, self.group_size, self.inner_k_tiles, self.padding_allowed
+            self.mod,
+            self.group_size,
+            self.inner_k_tiles,
+            self.padding_allowed,
+            self.activation_precision,
+            self.weight_precision,
         )
         return self.mod

@@ -891,6 +922,8 @@ def __init__(
         dtype=None,
         group_size: int = 128,
         inner_k_tiles: int = 8,
+        activation_precision: torch.dtype = torch.float16,
+        weight_precision: torch.dtype = torch.float16,
     ) -> None:
         super().__init__()
         # always pad if needed since it becomes a noop at runtime if not needed
@@ -903,7 +936,8 @@ def __init__(
         assert not bias, "require bias=False"
         self.group_size = group_size
         self.inner_k_tiles = inner_k_tiles
-        self.precision = torch.float16
+        self.weight_precision = weight_precision
+        self.activation_precision = activation_precision

         # assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert (
@@ -917,12 +951,13 @@ def __init__(
         self.register_buffer(
             "scales_and_zeros",
             torch.empty(
-                (in_features // group_size, out_features, 2), dtype=self.precision
+                (in_features // group_size, out_features, 2),
+                dtype=self.weight_precision,
             ),
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input = input.to(self.precision)
+        input = input.to(self.activation_precision)
         input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))

         (
@@ -937,15 +972,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input, scales, zero_points, quant_min, quant_max, torch.int8
         )
         input = torch.ops.quantized_decomposed.dequantize_per_token(
-            input, scales, zero_points, quant_min, quant_max, torch.int8, self.precision
+            input,
+            scales,
+            zero_points,
+            quant_min,
+            quant_max,
+            torch.int8,
+            self.activation_precision,
         )

-        input = input.to(self.precision)
+        input = input.to(self.activation_precision)
         return linear_forward_int4(
             input,
             self.weight,
             self.scales_and_zeros,
             self.out_features,
             self.group_size,
-            self.precision,
+            self.weight_precision,
         )
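
Usage sketch (not part of the diff): with the new keyword arguments, the caller can choose the activation and weight dtypes independently instead of the previously hard-coded torch.float16. The handler class and its create_quantized_state_dict / convert_for_runtime methods come from the code above; `model` is a placeholder for any module containing nn.Linear layers, and the dtype choices here are illustrative.

    import torch

    # `model` is a hypothetical nn.Module with nn.Linear submodules.
    handler = Int8DynActInt4WeightQuantHandler(
        model,
        group_size=128,
        inner_k_tiles=8,
        padding_allowed=True,
        activation_precision=torch.float32,  # dtype used around per-token activation quant/dequant
        weight_precision=torch.float16,      # dtype used for scales_and_zeros and weight packing
    )
    quantized_state_dict = handler.create_quantized_state_dict()
    model = handler.convert_for_runtime()    # swaps nn.Linear for Int8DynActInt4WeightLinear
    model.load_state_dict(quantized_state_dict)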