     state_dict_device,
 )
 
-from quantization.qops import (
-    LinearAct8Int4DQ,
-    LinearInt4 as WeightOnlyInt4Linear,
-    LinearInt8 as WeightOnlyInt8Linear,
-    QuantizedEmbedding,
+from quantization.qops import LinearAct8Int4DQ, QuantizedEmbedding
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
+from torchao.quantization.GPTQ import (
+    Int4WeightOnlyQuantizer,
+    Int8DynActInt4WeightQuantizer,
 )
 
 
@@ -50,12 +50,35 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
     quantize_options = json.loads(quantize_options)
 
     for quantizer, q_kwargs in quantize_options.items():
-        if quantizer not in quantizer_class_dict:
+        if (
+            quantizer not in quantizer_class_dict
+            and quantizer not in ao_quantizer_class_dict
+        ):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
-
-        model = quantizer_class_dict[quantizer](
-            model, device=device, tokenizer=tokenizer, **q_kwargs
-        ).quantized_model()
+        if quantizer in ao_quantizer_class_dict:
+            dtype = quantize_options.get("precision", {}).get("dtype", "float16")
+            precision = name_to_dtype(dtype, device)
+            try:
+                # Easier to ask forgiveness than permission
+                quant_handler = ao_quantizer_class_dict[quantizer](
+                    groupsize=q_kwargs["groupsize"], device=device, precision=precision
+                )
+            except TypeError as e:
+                if "unexpected keyword argument 'device'" in str(e):
+                    quant_handler = ao_quantizer_class_dict[quantizer](
+                        groupsize=q_kwargs["groupsize"], precision=precision
+                    )
+                elif "unexpected keyword argument 'precision'" in str(e):
+                    quant_handler = ao_quantizer_class_dict[quantizer](
+                        groupsize=q_kwargs["groupsize"], device=device
+                    )
+                else:
+                    raise e
+            model = quant_handler.quantize(model)
+        else:
+            model = quantizer_class_dict[quantizer](
+                model, device=device, tokenizer=tokenizer, **q_kwargs
+            ).quantized_model()
 
 
 #########################################################################
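
The try/except chain added above probes the torchao quantizer constructors because different torchao releases accept different subsets of the device and precision keyword arguments. A minimal standalone sketch of the same fallback idea, using an illustrative helper name (construct_with_fallback is not part of this change) and a single retry:

    def construct_with_fallback(cls, **kwargs):
        # Try the full keyword set first; if the constructor rejects one
        # keyword, drop it and retry once (EAFP, as in the diff above).
        try:
            return cls(**kwargs)
        except TypeError as e:
            msg = str(e)
            if "unexpected keyword argument" not in msg:
                raise
            rejected = msg.split("'")[1]  # name of the rejected keyword
            kwargs.pop(rejected, None)
            return cls(**kwargs)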
@@ -509,176 +532,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)
 
 
-#########################################################################
-#####    weight only int4 per channel groupwise quantized code     ######
-
-
-class WeightOnlyInt4QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device=None,
-        *,
-        tokenizer=None,
-        groupsize=128,
-        inner_k_tiles=8,
-        padding_allowed=True,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
-        assert groupsize in [32, 64, 128, 256]
-        assert inner_k_tiles in [2, 4, 8]
-
-    @torch.no_grad()
-    def quantize(self, module):
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, torch.nn.Linear):
-                assert not child.bias
-                out_features = child.out_features
-                in_features = child.in_features
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                weight = child.weight.data
-                if not WeightOnlyInt4Linear._check_k(
-                    k=in_features,
-                    groupsize=self.groupsize,
-                    inner_k_tiles=self.inner_k_tiles,
-                ):
-                    if self.padding_allowed:
-                        # print(
-                        #     f"warning: {name} is padded to satisfy in_features % 1024 == 0"
-                        # )
-                        padded_in_features = find_multiple(in_features, 1024)
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        print(
-                            f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
-                        )
-                        continue
-                weight_int4pack, scales_and_zeros = (
-                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
-                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
-                    )
-                )
-                weight_int4pack = weight_int4pack.to(device=self.device)
-                scales_and_zeros = scales_and_zeros.to(device=self.device)
-
-                setattr(
-                    module,
-                    name,
-                    WeightOnlyInt4Linear(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=self.device,
-                        groupsize=self.groupsize,
-                        inner_k_tiles=self.inner_k_tiles,
-                        weight=weight_int4pack,
-                        scales_and_zeros=scales_and_zeros,
-                    ),
-                )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
-#########################################################################
-#####    weight only int4 per channel groupwise quantized code     ######
-
-
-class Int8DynActInt4WeightQuantizer(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device=None,
-        dtype=None,
-        *,
-        tokenizer=None,
-        groupsize=128,
-        padding_allowed=True,
-        precision=torch.float32,
-        scales_precision=torch.float32,
-    ):
-        if dtype is None:
-            dtype = torch.float32
-
-        self.model_ = model
-        self.device = device
-        self.dtype = dtype
-
-        self.groupsize = groupsize
-        self.padding_allowed = padding_allowed
-        self.precision = precision
-        self.scales_precision = scales_precision
-        assert groupsize in [32, 64, 128, 256]
-
-    @torch.no_grad()
-    def quantize(self, module):
-        from torchao.quantization.quant_primitives import (
-            group_quantize_tensor_symmetric,
-        )
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, torch.nn.Linear):
-                out_features = child.out_features
-                in_features = child.in_features
-                weight = child.weight.data
-                assert not child.bias
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                # if self.padding_allowed:
-                #     padding_multiple=max(self.groupsize, 1024)
-                padding_multiple = self.groupsize
-                padded_in_features = find_multiple(in_features, padding_multiple)
-                weight = F.pad(weight, pad=(0, padded_in_features - in_features))
-                (
-                    weight_int8,
-                    scales,
-                    zeros,
-                ) = group_quantize_tensor_symmetric(
-                    weight.float(),
-                    4,  # n_bit
-                    self.groupsize,
-                    self.scales_precision,
-                )
-
-                setattr(
-                    module,
-                    name,
-                    LinearAct8Int4DQ(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=self.device,
-                        dtype=self.dtype,
-                        groupsize=self.groupsize,
-                        weight=weight_int8.to(device=self.device),
-                        scales=scales.to(device=self.device),
-                    ),
-                )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 ##########################################################################
 ###                       quantization dictionary                      ###
 
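The two handler classes deleted above are superseded by torchao's Int4WeightOnlyQuantizer and Int8DynActInt4WeightQuantizer, now imported at the top of the file. A rough sketch of the replacement path, assuming a torchao build whose constructor accepts groupsize and precision (the exact keywords vary by release, which is why quantize_model probes them):

    import torch
    from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

    # 8-bit dynamic activations, 4-bit grouped weights, groupsize 128;
    # `model` is assumed to be an existing float nn.Module.
    quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=torch.float32)
    model = quantizer.quantize(model)
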
@@ -688,8 +541,11 @@ def quantized_model(self) -> nn.Module:
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
-    "linear:int4": WeightOnlyInt4QuantHandler,
-    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
 }
+
+ao_quantizer_class_dict = {
+    "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
+}
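
For reference, a hypothetical options payload that the updated dispatch loop would handle; the key names come from this diff, but the surrounding CLI plumbing is not shown here, and `model`, `device`, and `tokenizer` are assumed to exist:

    # "precision" feeds name_to_dtype(); "linear:int4" is looked up in
    # ao_quantizer_class_dict and routed to torchao's Int4WeightOnlyQuantizer.
    options = '{"precision": {"dtype": "float16"}, "linear:int4": {"groupsize": 128}}'
    quantize_model(model, device, options, tokenizer=tokenizer)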