 )
 from qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

+from qops import (
+    LinearInt4 as WeightOnlyInt4Linear,
+    LinearInt8 as WeightOnlyInt8Linear,
+    QuantizedEmbedding,
+)
+

 #########################################################################
 ### torchchat quantization API ###
@@ -606,31 +612,6 @@ def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1):
     return find_multiple(k, 1024)


-def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
-    origin_x_size = x.size()
-    x = x.reshape(-1, origin_x_size[-1])
-
-    if "cuda" in str(x.device):
-        c = torch.ops.aten._weight_int4pack_mm(
-            x.to(torch.bfloat16),
-            weight_int4pack,
-            groupsize,
-            scales_and_zeros.to(torch.bfloat16),
-        ).to(
-            x.dtype
-        )  # cast back to x.dtype
-    else:
-        c = torch.ops.aten._weight_int4pack_mm(
-            x,
-            weight_int4pack,
-            groupsize,
-            scales_and_zeros,
-        )
-    new_shape = origin_x_size[:-1] + (out_features,)
-    c = c.reshape(new_shape)
-    return c
-
-
 def replace_linear_int4(
     module,
     device,
@@ -640,9 +621,10 @@ def replace_linear_int4(
 ):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if (
-                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
-                or padding_allowed
+            if padding_allowed or WeightOnlyInt4Linear._check_k(
+                k=child.in_features,
+                groupsize=groupsize,
+                inner_k_tiles=inner_k_tiles,
             ):
                 setattr(
                     module,
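Note: the qops implementation of LinearInt4._check_k is not shown in this diff. The following is only a minimal sketch consistent with the _check_linear_int4_k helper removed further down in this commit, not the actual qops code.

import torch

class LinearInt4(torch.nn.Module):  # exported from qops as WeightOnlyInt4Linear
    @classmethod
    def _check_k(cls, k, groupsize=1, inner_k_tiles=1):
        # k must divide evenly into quantization groups and into the
        # (inner_k_tiles * 16)-wide tiles used by the int4 packed kernel
        return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0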
@@ -704,8 +686,10 @@ def create_quantized_state_dict(self):
                 # print(f"linear: {fqn}, in={in_features}, out={out_features}")

                 weight = mod.weight.data
-                if not _check_linear_int4_k(
-                    in_features, self.groupsize, self.inner_k_tiles
+                if not WeightOnlyInt4Linear._check_k(
+                    k=in_features,
+                    groupsize=self.groupsize,
+                    inner_k_tiles=self.inner_k_tiles,
                 ):
                     if self.padding_allowed:
                         print(
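Note: when _check_k fails and padding_allowed is set, the handler pads in_features up to an aligned size before packing. A small worked sketch of that alignment rule follows; the sizes are made up, and the local find_multiple mirrors the build.utils helper this file imports.

import torch
import torch.nn.functional as F

def find_multiple(n, k):
    # same rounding-up behavior as build.utils.find_multiple
    return n if n % k == 0 else n + k - (n % k)

in_features, out_features, groupsize, inner_k_tiles = 1000, 4096, 128, 8

# 1000 is not a multiple of groupsize (128) or of inner_k_tiles * 16 (128),
# so the check fails; pad the reduction dim up to the next multiple of 1024.
padded_in = find_multiple(in_features, 1024)              # -> 1024
weight = torch.randn(out_features, in_features)
weight = F.pad(weight, pad=(0, padded_in - in_features))  # zero-pad last dim
assert padded_in % groupsize == 0 and padded_in % (inner_k_tiles * 16) == 0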
@@ -751,85 +735,23 @@ def quantized_model(self) -> nn.Module:
         return self.model_


-class WeightOnlyInt4Linear(torch.nn.Module):
-    __constants__ = ["in_features", "out_features"]
-    in_features: int
-    out_features: int
-    weight: torch.Tensor
-    scales_and_zeros: torch.Tensor
-
-    def __init__(
-        self,
-        device: str,
-        in_features: int,
-        out_features: int,
-        bias=True,
-        dtype=None,
-        groupsize: int = 128,
-        inner_k_tiles: int = 8,
-    ) -> None:
-        super().__init__()
-        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
-        if self.padding:
-            self.origin_in_features = in_features
-            in_features = find_multiple(in_features, 1024)
-
-        self.in_features = in_features
-        self.out_features = out_features
-        assert not bias, "require bias=False"
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-
-        assert out_features % 8 == 0, "require out_features % 8 == 0"
-        assert (
-            in_features % (inner_k_tiles * 16) == 0
-        ), "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.empty(
-                (
-                    out_features // 8,
-                    in_features // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
-                ),
-                dtype=torch.int32,
-                device=device,
-            ),
-        )
-        self.register_buffer(
-            "scales_and_zeros",
-            torch.empty(
-                (in_features // groupsize, out_features, 2),
-                dtype=get_precision(),
-                device=device,
-            ),
-        )
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.padding:
-            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
-        return linear_forward_int4(
-            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
-        )
-
-
 #########################################################################
 ##### GPTQ #####


-def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
-    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
-
-
 class GPTQQuantHandler(QuantHandler):
-    """
-    This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
-    Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
-    __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
-
-    The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
-    create_quantized_state_dict. Here is a description of each function.
+    """This class implements a GPTQ QuantHandler that can be used to
+    apply GPTQ to a model in concert with the GenericGPTQRunner class.
+    Unlike the base QuantHandler class, the user does not need to
+    implement the create_quantized_state_dict, instead they have to
+    reimplement __init__ such that it defines the functions for the
+    quantization mode. User is expected to reimplement
+    convert_for_runtime.
+
+    The following functions (which must be defined in __init__) are
+    used to define the quantization mode for both GPTQ and
+    create_quantized_state_dict. Here is a description of each
+    function.

     get_qparams_func:
         A function that calculates the quantization qparams for an input tensor.
@@ -839,9 +761,11 @@ class GPTQQuantHandler(QuantHandler):
             qparams: it can have any format but will need to be handled by the other defined functions below.

     quantize_func:
-        A function that applies quantization to an input tensor. It should be noted
-        that this function needs to be able to handle quantizing the entire weight tensor, a single group,
-        or a single column.
+        A function that applies quantization to an input tensor. It
+        should be noted that this function needs to be able to handle
+        quantizing the entire weight tensor, a single group, or a
+        single column.
+
         Args:
             weight: A 2d weight tensor with non-integer dtype.
             qparams: the output from get_qparams_func
@@ -850,9 +774,11 @@ class GPTQQuantHandler(QuantHandler):


     dequantize_func:
-        A function that dequantizes an input quantized weight tensor. It should be noted
-        that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
-        or a single column.
+        A function that dequantizes an input quantized weight
+        tensor. It should be noted that this function needs to be able
+        to handle dequantizing the entire weight tensor, a single
+        group, or a single column.
+
         Args:
             quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
             qparams: the output from get_qparams_func
@@ -861,6 +787,7 @@ class GPTQQuantHandler(QuantHandler):

     combine_qparams_list_func:
         A function that combines several qparams into one qparam.
+
         Args:
             qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
             on a single group from a weight tensor
@@ -875,13 +802,17 @@ class GPTQQuantHandler(QuantHandler):
             skip: boolean indicating whether layer should be skipped

     make_names_and_values_dict_func:
-        A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
-        should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
+        A function that prepares the qparams and quantized_weight and
+        creates a dictionary indicating how they should be inserted
+        into the state_dict. Generally any packing of the weight and
+        qparams should be done here.
+
         Args:
             quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
             qparams: the output from get_qparams_func
         Returns:
-            names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
+            names_and_values_dict: a dictionary mapping the name of
+            the parameters of the quantized module to the
             corresponding quantized weights and qparams.
     """

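Note: the docstring above lists the callables a subclass must define in __init__. The sketch below shows what per-group 4-bit affine versions of four of them could look like; the formulas, signatures, and groupsize default are illustrative only and are not the exact functions torchchat defines.

import torch

n_bit = 4

def get_qparams_func(weight, groupsize=128):
    # weight: (out_features, k); one (scale, zero) pair per group of `groupsize` columns
    w = weight.reshape(weight.shape[0], -1, groupsize)
    wmin, wmax = w.amin(dim=-1), w.amax(dim=-1)
    scale = (wmax - wmin).clamp(min=1e-6) / (2**n_bit - 1)
    return scale, wmin                        # each (out_features, k // groupsize)

def quantize_func(weight, qparams):
    # handles the whole tensor, a single group, or a single column: the slice
    # width is inferred from the number of groups carried by qparams
    scale, zero = qparams
    gs = weight.shape[-1] // scale.shape[-1]
    w = weight.reshape(weight.shape[0], scale.shape[-1], gs)
    q = ((w - zero.unsqueeze(-1)) / scale.unsqueeze(-1)).round().clamp(0, 2**n_bit - 1)
    return q.to(torch.int32).reshape(weight.shape)

def dequantize_func(quantized_weight, qparams):
    scale, zero = qparams
    gs = quantized_weight.shape[-1] // scale.shape[-1]
    q = quantized_weight.reshape(quantized_weight.shape[0], scale.shape[-1], gs)
    return (q.to(scale.dtype) * scale.unsqueeze(-1) + zero.unsqueeze(-1)).reshape(
        quantized_weight.shape
    )

def combine_qparams_list_func(qparams_list):
    # stitch per-group qparams back together along the group dimension
    return [torch.cat(xs, dim=1) for xs in zip(*qparams_list)]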
@@ -1026,14 +957,20 @@ def __init__(
         ]
         # skip unless padding_allowed=True or its correctly sized
         self.skip_layer_func = lambda linear_weight: not (
-            _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles)
-            or padding_allowed
+            padding_allowed
+            or WeightOnlyInt4Linear._check_k(
+                k=linear_weight.shape[-1],
+                groupsize=groupsize,
+                inner_k_tiles=inner_k_tiles,
+            )
         )

         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
+            if not WeightOnlyInt4Linear._check_k(
+                k=k, groupsize=groupsize, inner_k_tiles=inner_k_tiles
+            ):
                 new_k = find_multiple(k, 1024)
             else:
                 new_k = k
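Note: a tiny self-contained check of the divisibility rule behind skip_layer_func with typical settings; the sizes are made up for illustration.

groupsize, inner_k_tiles = 128, 8
for k in (4096, 4095):
    ok = k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
    padded = -(-k // 1024) * 1024   # what find_multiple(k, 1024) would return
    print(k, "quantized as-is" if ok else f"padded to {padded} if padding_allowed, else skipped")
# 4096 quantized as-is
# 4095 padded to 4096 if padding_allowed, else skipped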