Skip to content

Commit b439edf

Browse files
committed
Fix 3
1 parent 25bfc2a commit b439edf

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

quantization/quantize.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
 from torchao.quantization.quant_api import (
     quantize_,
     int4_weight_only,
+    Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
 )

@@ -49,7 +50,6 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
4950
if (
5051
quantizer not in quantizer_class_dict
5152
and quantizer not in ao_quantizer_class_dict
52-
and quantizer not in ao_quant_api_dict
5353
):
5454
raise RuntimeError(f"unknown quantizer {quantizer} specified")
5555
if quantizer in ao_quantizer_class_dict:
@@ -59,6 +59,12 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
5959
precision = name_to_dtype(dtype, device)
6060
else:
6161
precision = get_precision()
62+
63+
# Only use quant API for dtype bf16 and CUDA
64+
if precision == torch.bfloat16 and device == "cuda":
65+
quantize_(model, int4_weight_only(group_size=q_kwargs["groupsize"]))
66+
continue
67+
6268
try:
6369
# Easier to ask forgiveness than permission
6470
quant_handler = ao_quantizer_class_dict[quantizer](
@@ -76,8 +82,6 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
7682
else:
7783
raise e
7884
model = quant_handler.quantize(model)
79-
elif quantizer in ao_quant_api_dict:
80-
quantize_(model, ao_quant_api_dict[quantizer](group_size=q_kwargs["groupsize"]))
8185
else:
8286
model = quantizer_class_dict[quantizer](
8387
model, device=device, tokenizer=tokenizer, **q_kwargs
@@ -549,9 +553,6 @@ def quantized_model(self) -> nn.Module:
549553
}
550554

 ao_quantizer_class_dict = {
+    "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
-
-ao_quant_api_dict = {
-    "linear:int4": int4_weight_only,
-}

0 commit comments

Comments
 (0)