@@ -54,7 +54,7 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
 
         model = quantizer_class_dict[quantizer](
-            model, device, tokenizer, **q_kwargs
+            model, device=device, tokenizer=tokenizer, **q_kwargs
         ).quantized_model()
 
 
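The switch to keyword arguments at this call site matters because the handler constructors in this file (see the new handler below) declare `tokenizer` and the remaining options as keyword-only. A minimal sketch of the pattern, with `ExampleHandler` as a hypothetical stand-in for a `QuantHandler` subclass:

class ExampleHandler:  # hypothetical stand-in, not a class from this file
    def __init__(self, model, device=None, *, tokenizer=None, groupsize=128):
        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer
        self.groupsize = groupsize

# ExampleHandler(model, "cuda", tok)                   # TypeError: tokenizer is keyword-only
# ExampleHandler(model, device="cuda", tokenizer=tok)  # matches the new call site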
@@ -450,7 +450,7 @@ def quantized_model(self) -> nn.Module:
 ##### embedding table quantization ######
 
 
-class EmbeddingOnlyInt8QuantHandler(QuantHandler):
+class EmbeddingOnlyQuantHandler(QuantHandler):
     def __init__(
         self,
         model: nn.Module,
@@ -545,6 +545,94 @@ def quantized_model(self) -> nn.Module:
 ##### weight only int4 per channel groupwise quantized code ######
 
 
+class NewWeightOnlyInt4QuantHandler(QuantHandler):
+    def __init__(
+        self,
+        model: nn.Module,
+        device=None,
+        *,
+        tokenizer=None,
+        groupsize=128,
+        inner_k_tiles=8,
+        padding_allowed=True,
+        weight: Optional[torch.Tensor] = None,
+        scales_and_zeros: Optional[torch.Tensor] = None,
+    ):
+        self.model_ = model
+        self.device = device
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+        self.padding_allowed = padding_allowed
+        assert groupsize in [32, 64, 128, 256]
+        assert inner_k_tiles in [2, 4, 8]
+
+    @torch.no_grad()
+    def quantize(self, module):
+        # cur_state_dict = state_dict_device(self.model_.state_dict())
+        # dict_device = "cpu"  # self.device
+
+        device = self.device
+
+        for name, child in module.named_children():
+            # print(f"name: {name}")
+            if isinstance(child, torch.nn.Linear):
+                assert not child.bias
+                out_features = child.out_features
+                in_features = child.in_features
+                assert out_features % 8 == 0, "require out_features % 8 == 0"
+                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                weight = child.weight.data
+                if not WeightOnlyInt4Linear._check_k(
+                    k=in_features,
+                    groupsize=self.groupsize,
+                    inner_k_tiles=self.inner_k_tiles,
+                ):
+                    if self.padding_allowed:
+                        print(
+                            f"warning: {name} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = find_multiple(in_features, 1024)
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        print(
+                            f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+                        )
+                        continue
+                weight_int4pack, scales_and_zeros = (
+                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
+                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
+                    )
+                )
+                weight_int4pack = weight_int4pack.to(device=self.device)
+                scales_and_zeros = scales_and_zeros.to(device=self.device)
+
+                setattr(
+                    module,
+                    name,
+                    WeightOnlyInt4Linear(
+                        child.in_features,
+                        child.out_features,
+                        bias=False,
+                        device=self.device,
+                        groupsize=self.groupsize,
+                        inner_k_tiles=self.inner_k_tiles,
+                        weight=weight_int4pack,
+                        scales_and_zeros=scales_and_zeros,
+                    ),
+                )
+            else:
+                self.quantize(child)
+
+        return module
+
+    def quantized_model(self) -> nn.Module:
+        return self.quantize(self.model_)
+
+
 def replace_linear_int4(
     module,
     device,
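The padding branch in the new handler calls find_multiple(in_features, 1024), a helper presumably imported from a shared utility module in this repo. A minimal sketch of its assumed behavior (round n up to the next multiple of k):

def find_multiple(n: int, k: int) -> int:
    # Assumed behavior: smallest multiple of k that is >= n.
    if n % k == 0:
        return n
    return n + k - (n % k)

# find_multiple(4096, 1024) == 4096; find_multiple(5000, 1024) == 5120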
@@ -563,10 +651,10 @@ def replace_linear_int4(
             module,
             name,
             WeightOnlyInt4Linear(
-                device,
                 child.in_features,
                 child.out_features,
                 bias=False,
+                device=device,
                 groupsize=groupsize,
                 inner_k_tiles=inner_k_tiles,
             ),
@@ -1001,9 +1089,9 @@ def quantized_model(self) -> nn.Module:
1001
1089
# Must come last because __future__ annotations don't work for naked
1002
1090
# class references
1003
1091
quantizer_class_dict = {
1004
- "embedding" : EmbeddingOnlyInt8QuantHandler ,
1092
+ "embedding" : EmbeddingOnlyQuantHandler ,
1005
1093
"linear:int8" : WeightOnlyInt8QuantHandler ,
1006
- "linear:int4" : WeightOnlyInt4QuantHandler ,
1094
+ "linear:int4" : NewWeightOnlyInt4QuantHandler ,
1007
1095
"linear:a8w4dq" : Int8DynActInt4WeightQuantizer ,
1008
1096
"linear:int4-gptq" : WeightOnlyInt4GPTQQuantHandler ,
1009
1097
"linear:hqq" : WeightOnlyInt4HqqQuantHandler ,