@@ -23,6 +23,7 @@
 from torchao.quantization.quant_api import (
     quantize_,
     int4_weight_only,
+    Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
 
@@ -49,7 +50,6 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
         if (
             quantizer not in quantizer_class_dict
             and quantizer not in ao_quantizer_class_dict
-            and quantizer not in ao_quant_api_dict
         ):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         if quantizer in ao_quantizer_class_dict:
@@ -59,6 +59,12 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
                 precision = name_to_dtype(dtype, device)
             else:
                 precision = get_precision()
+
+            # Only use quant API for dtype bf16 and CUDA
+            if precision == torch.bfloat16 and device == "cuda":
+                quantize_(model, int4_weight_only(group_size=q_kwargs["groupsize"]))
+                continue
+
             try:
                 # Easier to ask forgiveness than permission
                 quant_handler = ao_quantizer_class_dict[quantizer](
@@ -76,8 +82,6 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
                 else:
                     raise e
             model = quant_handler.quantize(model)
-        elif quantizer in ao_quant_api_dict:
-            quantize_(model, ao_quant_api_dict[quantizer](group_size=q_kwargs["groupsize"]))
         else:
             model = quantizer_class_dict[quantizer](
                 model, device=device, tokenizer=tokenizer, **q_kwargs
@@ -549,9 +553,6 @@ def quantized_model(self) -> nn.Module:
 }
 
 ao_quantizer_class_dict = {
+    "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
-
-ao_quant_api_dict = {
-    "linear:int4": int4_weight_only,
-}
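For context, a minimal sketch of the two "linear:int4" paths this change wires up. It assumes a torchao build whose quant_api exposes the three names imported in the first hunk; the toy model, group size, and device below are placeholders for illustration, not values from this repo:

import torch
import torch.nn as nn
from torchao.quantization.quant_api import (
    quantize_,
    int4_weight_only,
    Int4WeightOnlyQuantizer,
)

model = nn.Sequential(nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)

# Fast path (the new in-loop branch): bf16 weights on CUDA go through
# the quantize_ API, which replaces Linear weights in place with int4
# tensor subclasses; the loop then continues to the next quantizer.
quantize_(model, int4_weight_only(group_size=128))

# Fallback (the new ao_quantizer_class_dict entry): other
# precision/device combinations are handled by the module-swap
# Quantizer. The keyword names mirror the existing quant_handler call
# site; the exact constructor signature depends on the installed
# torchao version.
quant_handler = Int4WeightOnlyQuantizer(
    groupsize=128, device="cuda", precision=torch.bfloat16
)
other_model = nn.Sequential(nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
other_model = quant_handler.quantize(other_model)

Collapsing ao_quant_api_dict into ao_quantizer_class_dict keeps "linear:int4" a single user-facing option: dispatch now keys on precision and device inside the loop rather than on which lookup table the quantizer name happens to live in.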