@@ -24,12 +24,15 @@
 )

 from quantization.qops import (
-    LinearAct8Int4DQ,
     LinearInt4 as WeightOnlyInt4Linear,
     LinearInt8 as WeightOnlyInt8Linear,
     QuantizedEmbedding,
 )

+# AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
+from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+

 #########################################################################
 ###                torchchat quantization API                         ###
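A note on the `# noqa` import added above: it looks unused, but it is load-bearing. Importing `torch.ao.quantization.fx._decomposed` has the side effect of registering the decomposed quantization ops, including `quantize_per_channel_group`, on the `quantized_decomposed` op namespace; without it, op lookup fails with the `AttributeError` quoted in the comment. A minimal sketch of the behavior it guards, assuming a PyTorch build that ships this op:

```python
import torch

# Importing the module registers the decomposed ops on the
# quantized_decomposed namespace; before the import, the op may be missing.
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

assert hasattr(torch.ops.quantized_decomposed, "quantize_per_channel_group")
```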
@@ -50,12 +53,40 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
     quantize_options = json.loads(quantize_options)

     for quantizer, q_kwargs in quantize_options.items():
-        if quantizer not in quantizer_class_dict:
+        if (
+            quantizer not in quantizer_class_dict
+            and quantizer not in ao_quantizer_class_dict
+        ):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
+        if quantizer in ao_quantizer_class_dict:
+            # Use dtype precision specified in user config, else fallback on global precision.
+            if "precision" in quantize_options:
+                dtype = quantize_options["precision"].get("dtype", str(get_precision()))
+                precision = name_to_dtype(dtype, device)
+            else:
+                precision = get_precision()

-        model = quantizer_class_dict[quantizer](
-            model, device=device, tokenizer=tokenizer, **q_kwargs
-        ).quantized_model()
+            try:
+                # Easier to ask forgiveness than permission
+                quant_handler = ao_quantizer_class_dict[quantizer](
+                    groupsize=q_kwargs["groupsize"], device=device, precision=precision
+                )
+            except TypeError as e:
+                if "unexpected keyword argument 'device'" in str(e):
+                    quant_handler = ao_quantizer_class_dict[quantizer](
+                        groupsize=q_kwargs["groupsize"], precision=precision
+                    )
+                elif "unexpected keyword argument 'precision'" in str(e):
+                    quant_handler = ao_quantizer_class_dict[quantizer](
+                        groupsize=q_kwargs["groupsize"], device=device
+                    )
+                else:
+                    raise e
+            model = quant_handler.quantize(model)
+        else:
+            model = quantizer_class_dict[quantizer](
+                model, device=device, tokenizer=tokenizer, **q_kwargs
+            ).quantized_model()


 #########################################################################
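The `try`/`except TypeError` ladder in this hunk exists because different torchao releases expose different constructor signatures: some lack the `device` keyword, some lack `precision`. The same EAFP pattern in isolation, with `make_quantizer` as a hypothetical stand-in for the torchao constructor:

```python
def construct_compat(make_quantizer, *, groupsize, device, precision):
    """EAFP sketch: retry with fewer kwargs when the constructor rejects one."""
    try:
        return make_quantizer(groupsize=groupsize, device=device, precision=precision)
    except TypeError as e:
        if "unexpected keyword argument 'device'" in str(e):
            return make_quantizer(groupsize=groupsize, precision=precision)  # older API
        if "unexpected keyword argument 'precision'" in str(e):
            return make_quantizer(groupsize=groupsize, device=device)  # older API
        raise  # unrelated TypeError: propagate unchanged
```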
@@ -594,91 +625,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)


-#########################################################################
-##### weight only int4 per channel groupwise quantized code ######
-
-
-class Int8DynActInt4WeightQuantizer(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device=None,
-        dtype=None,
-        *,
-        tokenizer=None,
-        groupsize=128,
-        padding_allowed=True,
-        precision=torch.float32,
-        scales_precision=torch.float32,
-    ):
-        if dtype is None:
-            dtype = torch.float32
-
-        self.model_ = model
-        self.device = device
-        self.dtype = dtype
-
-        self.groupsize = groupsize
-        self.padding_allowed = padding_allowed
-        self.precision = precision
-        self.scales_precision = scales_precision
-        assert groupsize in [32, 64, 128, 256]
-
-    @torch.no_grad()
-    def quantize(self, module):
-        from torchao.quantization.quant_primitives import (
-            group_quantize_tensor_symmetric,
-        )
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, torch.nn.Linear):
-                out_features = child.out_features
-                in_features = child.in_features
-                weight = child.weight.data
-                assert not child.bias
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                # if self.padding_allowed:
-                #     padding_multiple = max(self.groupsize, 1024)
-                padding_multiple = self.groupsize
-                padded_in_features = find_multiple(in_features, padding_multiple)
-                weight = F.pad(weight, pad=(0, padded_in_features - in_features))
-                (
-                    weight_int8,
-                    scales,
-                    zeros,
-                ) = group_quantize_tensor_symmetric(
-                    weight.float(),
-                    4,  # n_bit
-                    self.groupsize,
-                    self.scales_precision,
-                )
-
-                setattr(
-                    module,
-                    name,
-                    LinearAct8Int4DQ(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=self.device,
-                        dtype=self.dtype,
-                        groupsize=self.groupsize,
-                        weight=weight_int8.to(device=self.device),
-                        scales=scales.to(device=self.device),
-                    ),
-                )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 ##########################################################################
 ### quantization dictionary ###

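The in-tree class deleted above is superseded by the `Int8DynActInt4WeightQuantizer` now imported from `torchao.quantization.quant_api`, so torchchat no longer carries its own a8w4dq implementation. The call site in `quantize_model` reduces to roughly the following sketch, assuming a torchao build whose constructor accepts `groupsize`:

```python
import torch
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

def a8w4dq_quantize(model: torch.nn.Module, groupsize: int = 256) -> torch.nn.Module:
    # 8-bit dynamically quantized activations, 4-bit grouped weights,
    # now handled entirely by torchao instead of the deleted class.
    quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
    return quantizer.quantize(model)
```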
@@ -689,7 +635,10 @@ def quantized_model(self) -> nn.Module:
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
     "linear:int4": WeightOnlyInt4QuantHandler,
-    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
 }
+
+ao_quantizer_class_dict = {
+    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
+}
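With the dictionaries split, `quantize_model` routes names in `quantizer_class_dict` through the in-tree `quantized_model()` path and `"linear:a8w4dq"` through the torchao quantizer. A hypothetical end-to-end call on a toy model; `quantize_model` swaps Linear modules in place, so the result is visible on the caller's reference:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(256, 256, bias=False))  # toy stand-in model
quantize_model(model, "cpu", '{"linear:a8w4dq": {"groupsize": 256}}')
```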