@@ -21,7 +21,6 @@
     get_precision,
     name_to_dtype,
     state_dict_device,
-    use_et_backend,
 )

 from qops import (
@@ -389,8 +388,6 @@ def quantize(self, module):
         # cur_state_dict = state_dict_device(self.model_.state_dict())
         # dict_device = "cpu"  # self.device

-        device = self.device
-
         if self.bitwidth == 4:
             range_min = -8
             range_max = 7
@@ -468,11 +465,6 @@ def __init__(

     @torch.no_grad()
     def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu"  # self.device
-
-        device = self.device
-
         if self.bitwidth == 4:
             range_min = -8
             range_max = 7
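Note: the hard-coded range_min/range_max pairs in these quantize() bodies are the standard two's-complement bounds for the given bitwidth. A quick illustrative check, not part of the diff:

    def signed_range(bitwidth):
        # two's-complement bounds: 4 -> (-8, 7), 8 -> (-128, 127)
        return -(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1

    assert signed_range(4) == (-8, 7)
    assert signed_range(8) == (-128, 127)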
@@ -544,8 +536,7 @@ def quantized_model(self) -> nn.Module:
 #########################################################################
 ##### weight only int4 per channel groupwise quantized code ######

-
-class NewWeightOnlyInt4QuantHandler(QuantHandler):
+class WeightOnlyInt4QuantHandler(QuantHandler):
     def __init__(
         self,
         model: nn.Module,
@@ -568,11 +559,6 @@ def __init__(

     @torch.no_grad()
     def quantize(self, module):
-        # cur_state_dict = state_dict_device(self.model_.state_dict())
-        # dict_device = "cpu"  # self.device
-
-        device = self.device
-
         for name, child in module.named_children():
             # print(f"name: {name}")
             if isinstance(child, torch.nn.Linear):
@@ -633,129 +619,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)


-def replace_linear_int4(
-    module,
-    device,
-    groupsize,
-    inner_k_tiles,
-    padding_allowed,
-):
-    for name, child in module.named_children():
-        if isinstance(child, nn.Linear):
-            if padding_allowed or WeightOnlyInt4Linear._check_k(
-                k=child.in_features,
-                groupsize=groupsize,
-                inner_k_tiles=inner_k_tiles,
-            ):
-                setattr(
-                    module,
-                    name,
-                    WeightOnlyInt4Linear(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=device,
-                        groupsize=groupsize,
-                        inner_k_tiles=inner_k_tiles,
-                    ),
-                )
-        else:
-            replace_linear_int4(
-                child, device, groupsize, inner_k_tiles, padding_allowed
-            )
-
-
-class WeightOnlyInt4QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device,
-        tokenizer=None,
-        *,
-        groupsize=128,
-        inner_k_tiles=8,
-        padding_allowed=True,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
-        assert groupsize in [32, 64, 128, 256]
-        assert inner_k_tiles in [2, 4, 8]
-
-    # @torch.no_grad()
-    # def p(self):
-    #     cur_state_dict = state_dict_device(self.model_.state_dict())
-    #     dict_device = "cpu"  # self.device
-    #
-    #     for fqn, mod in self.model_.named_modules():
-    #         if hasattr(mod, "weight"):
-    #             print(f"device={str(mod.weight.data.device)}")
-
-    @torch.no_grad()
-    def create_quantized_state_dict(self):
-        cur_state_dict = state_dict_device(self.model_.state_dict())
-        dict_device = "cpu"  # self.device
-
-        for fqn, mod in self.model_.named_modules():
-            if isinstance(mod, torch.nn.Linear):
-                assert not mod.bias
-                out_features = mod.out_features
-                in_features = mod.in_features
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                weight = mod.weight.data
-                if not WeightOnlyInt4Linear._check_k(
-                    k=in_features,
-                    groupsize=self.groupsize,
-                    inner_k_tiles=self.inner_k_tiles,
-                ):
-                    if self.padding_allowed:
-                        print(
-                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
-                        )
-                        padded_in_features = find_multiple(in_features, 1024)
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        print(
-                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
-                        )
-                        continue
-                weight_int4pack, scales_and_zeros = (
-                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
-                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
-                    )
-                )
-                weight_int4pack = weight_int4pack.to(device=dict_device)
-                scales_and_zeros = scales_and_zeros.to(device=dict_device)
-                cur_state_dict[f"{fqn}.weight"] = weight_int4pack
-                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
-
-        return cur_state_dict
-
-    def convert_for_runtime(self):
-        replace_linear_int4(
-            self.model_,
-            self.device,
-            self.groupsize,
-            self.inner_k_tiles,
-            self.padding_allowed,
-        )
-        return self.model_
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        # self.p()
-        return self.model_
-
-
 #########################################################################
 ##### GPTQ #####

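Note: with the free-standing replace_linear_int4 and the duplicate handler deleted, int4 weight-only quantization now has a single entry point that quantizes in place instead of building a separate state dict and reloading it. A minimal usage sketch under the signature visible in this diff (the toy module and device string are illustrative; shapes are chosen so the _check_k and out_features % 8 constraints pass):

    import torch.nn as nn

    toy = nn.Sequential(nn.Linear(4096, 4096, bias=False))
    handler = WeightOnlyInt4QuantHandler(model=toy, device="cpu", groupsize=128)
    quantized = handler.quantized_model()  # now simply self.quantize(self.model_)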
@@ -1011,13 +874,35 @@ def make_names_and_values_dict_func(q, qparams):
         self.make_names_and_values_dict_func = make_names_and_values_dict_func
         super().__init__()

+    def replace_linear_int4(
+        self,
+        module,
+    ):
+        for name, child in module.named_children():
+            if isinstance(child, nn.Linear):
+                if self.padding_allowed or WeightOnlyInt4Linear._check_k(
+                    k=child.in_features,
+                    groupsize=self.groupsize,
+                    inner_k_tiles=self.inner_k_tiles,
+                ):
+                    setattr(
+                        module,
+                        name,
+                        WeightOnlyInt4Linear(
+                            child.in_features,
+                            child.out_features,
+                            bias=False,
+                            device=self.device,
+                            groupsize=self.groupsize,
+                            inner_k_tiles=self.inner_k_tiles,
+                        ),
+                    )
+            else:
+                self.replace_linear_int4(child)
+
     def convert_for_runtime(self):
-        replace_linear_int4(
+        self.replace_linear_int4(
             self.model_,
-            self.device,
-            self.groupsize,
-            self.inner_k_tiles,
-            self.padding_allowed,
         )
         return self.model_

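Note: the method added above is the usual recursive module-swap pattern: walk named_children(), replace each eligible nn.Linear via setattr, and recurse into container modules. Stripped of the int4 specifics, a generic sketch (not part of the PR; unlike the PR's version it replaces every Linear unconditionally):

    import torch.nn as nn

    def swap_linears(module: nn.Module, make_replacement) -> None:
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                setattr(module, name, make_replacement(child))
            else:
                swap_linears(child, make_replacement)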
@@ -1048,7 +933,8 @@ def __init__(self, model: nn.Module, device, tokenizer=None, *, groupsize):
         self.device = device
         self.groupsize = groupsize

-    def create_quantized_state_dict(self):
+    @torch.no_grad()
+    def quantize(self, module):
         from hqq.core.quantize import Quantizer

         for m in self.model_.modules():
@@ -1066,20 +952,11 @@ def create_quantized_state_dict(self):
                 )

         return WeightOnlyInt4QuantHandler(
-            self.model_, self.device, groupsize=self.groupsize
-        ).create_quantized_state_dict()
-
-    def convert_for_runtime(self):
-        # ALSO: all code must work for CPU, CUDA, MPS
-        return WeightOnlyInt4GPTQQuantHandler(
-            self.model_, self.device, tokenizer=None, groupsize=self.groupsize
-        ).convert_for_runtime()
+            model=self.model_, device=self.device, groupsize=self.groupsize
+        ).quantize(self.model_)

     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
+        return self.quantize(self.model_)


 ##########################################################################
@@ -1091,7 +968,7 @@ def quantized_model(self) -> nn.Module:
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
-    "linear:int4": NewWeightOnlyInt4QuantHandler,
+    "linear:int4": WeightOnlyInt4QuantHandler,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
     "linear:int4-gptq": WeightOnlyInt4GPTQQuantHandler,
     "linear:hqq": WeightOnlyInt4HqqQuantHandler,