@@ -15,17 +15,23 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from build.utils import get_device_str, get_precision, name_to_dtype, state_dict_device
+from build.utils import (
+    find_multiple,
+    get_device_str,
+    get_precision,
+    name_to_dtype,
+    state_dict_device,
+)
+
+from quantization.qops import (
+    LinearInt4 as WeightOnlyInt4Linear,
+    LinearInt8 as WeightOnlyInt8Linear,
+    QuantizedEmbedding,
+)

-from quantization.qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

# AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
-from torchao.quantization.quant_api import (
-    quantize_,
-    int4_weight_only,
-    Int4WeightOnlyQuantizer,
-    Int8DynActInt4WeightQuantizer,
-)
+from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer


#########################################################################
@@ -60,12 +66,6 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
        else:
            precision = get_precision()

-        # Only use quant API for dtype bf16 and CUDA
-        if quantizer == "linear:int4" and precision == torch.bfloat16 and device == "cuda":
-            quantize_(model, int4_weight_only(group_size=q_kwargs["groupsize"]))
-            model.to(device="cuda")
-            continue
-
        try:
            # Easier to ask forgiveness than permission
            quant_handler = ao_quantizer_class_dict[quantizer](
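Note on the hunk above: the surrounding loop is not shown in the diff. Below is a hedged, standalone sketch of how the dispatch presumably reads after this change, with `linear:int4` now resolved through the eager `quantizer_class_dict` path instead of the removed torchao fast path. Only `quantizer_class_dict`, `ao_quantizer_class_dict`, and the call sites visible in the diff are taken from the source; the rest is an assumption.

```python
# Hypothetical rendering of the dispatch around this hunk (assumed shape).
# Handler construction follows WeightOnlyInt4QuantHandler.__init__ below.
for quantizer, q_kwargs in quantize_options.items():
    if quantizer in quantizer_class_dict:
        # Eager path: now covers "linear:int4" as well as int8/embedding.
        model = quantizer_class_dict[quantizer](
            model, device, tokenizer=tokenizer, **q_kwargs
        ).quantized_model()
    else:
        # torchao path: only "linear:a8w4dq" remains after this change.
        quant_handler = ao_quantizer_class_dict[quantizer](**q_kwargs)
        model = quant_handler.quantize(model)
```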
@@ -540,6 +540,91 @@ def quantized_model(self) -> nn.Module:
        return self.quantize(self.model_)


+#########################################################################
+##### weight only int4 per channel groupwise quantized code ######
+
+
+class WeightOnlyInt4QuantHandler(QuantHandler):
+    def __init__(
+        self,
+        model: nn.Module,
+        device=None,
+        *,
+        tokenizer=None,
+        groupsize=128,
+        inner_k_tiles=8,
+        padding_allowed=True,
+    ):
+        self.model_ = model
+        self.device = device
+        self.groupsize = groupsize
+        self.inner_k_tiles = inner_k_tiles
+        self.padding_allowed = padding_allowed
+        assert groupsize in [32, 64, 128, 256]
+        assert inner_k_tiles in [2, 4, 8]
+
+    @torch.no_grad()
+    def quantize(self, module):
+        for name, child in module.named_children():
+            # print(f"name: {name}")
+            if isinstance(child, torch.nn.Linear):
+                assert not child.bias
+                out_features = child.out_features
+                in_features = child.in_features
+                assert out_features % 8 == 0, "require out_features % 8 == 0"
+                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                weight = child.weight.data
+                if not WeightOnlyInt4Linear._check_k(
+                    k=in_features,
+                    groupsize=self.groupsize,
+                    inner_k_tiles=self.inner_k_tiles,
+                ):
+                    if self.padding_allowed:
+                        # print(
+                        #     f"warning: {name} is padded to satisfy in_features % 1024 == 0"
+                        # )
+                        padded_in_features = find_multiple(in_features, 1024)
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        print(
+                            f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
+                        )
+                        continue
+                weight_int4pack, scales_and_zeros = (
+                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
+                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
+                    )
+                )
+                weight_int4pack = weight_int4pack.to(device=self.device)
+                scales_and_zeros = scales_and_zeros.to(device=self.device)
+
+                setattr(
+                    module,
+                    name,
+                    WeightOnlyInt4Linear(
+                        child.in_features,
+                        child.out_features,
+                        bias=False,
+                        device=self.device,
+                        groupsize=self.groupsize,
+                        inner_k_tiles=self.inner_k_tiles,
+                        weight=weight_int4pack,
+                        scales_and_zeros=scales_and_zeros,
+                    ),
+                )
+            else:
+                self.quantize(child)
+
+        return module
+
+    def quantized_model(self) -> nn.Module:
+        return self.quantize(self.model_)
+


##########################################################################
### quantization dictionary ###

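The padding branch in the new handler rounds `in_features` up to the next multiple of 1024 before packing. `find_multiple` is imported from `build.utils`; the sketch below reimplements its assumed semantics, inferred from the call site, to make the padding arithmetic concrete.

```python
def find_multiple(n: int, k: int) -> int:
    # Assumed behavior of build.utils.find_multiple: round n up to the
    # nearest multiple of k (n is returned unchanged if k already divides it).
    if n % k == 0:
        return n
    return n + k - (n % k)

# A Linear with in_features=4000 would be zero-padded by F.pad to 4096
# columns before WeightOnlyInt4Linear packs it.
assert find_multiple(4000, 1024) == 4096
assert find_multiple(4096, 1024) == 4096
```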
@@ -549,11 +634,11 @@ def quantized_model(self) -> nn.Module:
quantizer_class_dict = {
    "embedding": EmbeddingOnlyQuantHandler,
    "linear:int8": WeightOnlyInt8QuantHandler,
+    "linear:int4": WeightOnlyInt4QuantHandler,
    "precision": PrecisionHandler,
    "executor": ExecutorHandler,
}

ao_quantizer_class_dict = {
-    "linear:int4": Int4WeightOnlyQuantizer,
    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}
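A minimal usage sketch of the new eager int4 path. The toy model, device choice, and option values are illustrative assumptions; real invocations go through `quantize_model` with a quantize config such as `{"linear:int4": {"groupsize": 128}}`.

```python
import torch.nn as nn

# Hypothetical toy model; shapes chosen so the handler's divisibility
# checks (out_features % 8 == 0, _check_k on in_features) pass.
model = nn.Sequential(nn.Linear(4096, 4096, bias=False))

handler = WeightOnlyInt4QuantHandler(model, device="cuda", groupsize=128)
model = handler.quantized_model()
# The nn.Linear child is now a WeightOnlyInt4Linear holding packed int4
# weights plus per-group scales_and_zeros on the requested device.
```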