
Commit 5a73f50

Update quantize.py to use torchao Quantizers
Summary:
Remove duplicate code for Int4WeightOnlyQuantizer and Int8DynActInt4WeightQuantizer and use the torchao API instead.

Test Plan:
```
python torchchat.py generate llama2 --quantize '{"linear:int4": {"groupsize": 256}, "precision": {"dtype":"float16"}, "executor":{"accelerator":"cpu"}}' --prompt "Once upon a time," --max-new-tokens 256
python torchchat.py generate llama2 --quantize '{"linear:a8w4dq": {"groupsize": 256}, "precision": {"dtype":"float16"}, "executor":{"accelerator":"cpu"}}' --prompt "Once upon a time," --max-new-tokens 256
```

Reviewers:
Subscribers:
Tasks:
Tags:
Parent: 9ea78b8 · Commit: 5a73f50
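For context, the `--quantize` argument shown in the test plan is parsed with `json.loads`; each top-level key selects a handler (with this change, `linear:int4` and `linear:a8w4dq` dispatch to torchao quantizers, while `precision` and `executor` keep their existing handlers). A minimal sketch of that shape, using the int4 config from the test plan above:

```python
import json

# Quantize config from the test plan above; json.loads turns it into a dict
# whose keys are dispatched to quantizer handlers inside quantize_model().
quantize_options = json.loads(
    '{"linear:int4": {"groupsize": 256}, '
    '"precision": {"dtype": "float16"}, '
    '"executor": {"accelerator": "cpu"}}'
)

for quantizer, q_kwargs in quantize_options.items():
    print(quantizer, q_kwargs)
# linear:int4 {'groupsize': 256}
# precision {'dtype': 'float16'}
# executor {'accelerator': 'cpu'}
```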

3 files changed: +39 −185 lines


install_requirements.sh

Lines changed: 1 addition & 0 deletions
```diff
@@ -64,6 +64,7 @@ fi
 # pip packages needed by exir.
 REQUIREMENTS_TO_INSTALL=(
   torch=="2.5.0.${NIGHTLY_VERSION}"
+  torchao-nightly=="2024.6.29"
 )

 # Install the requirements. `--extra-index-url` tells pip to look for package
```

quantization/quantize.py

Lines changed: 38 additions & 182 deletions
```diff
@@ -23,11 +23,11 @@
     state_dict_device,
 )

-from quantization.qops import (
-    LinearAct8Int4DQ,
-    LinearInt4 as WeightOnlyInt4Linear,
-    LinearInt8 as WeightOnlyInt8Linear,
-    QuantizedEmbedding,
+from quantization.qops import LinearAct8Int4DQ, QuantizedEmbedding
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
+from torchao.quantization.GPTQ import (
+    Int4WeightOnlyQuantizer,
+    Int8DynActInt4WeightQuantizer,
 )


@@ -50,12 +50,35 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
     quantize_options = json.loads(quantize_options)

     for quantizer, q_kwargs in quantize_options.items():
-        if quantizer not in quantizer_class_dict:
+        if (
+            quantizer not in quantizer_class_dict
+            and quantizer not in ao_quantizer_class_dict
+        ):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
-
-        model = quantizer_class_dict[quantizer](
-            model, device=device, tokenizer=tokenizer, **q_kwargs
-        ).quantized_model()
+        if quantizer in ao_quantizer_class_dict:
+            dtype = quantize_options.get("precision", {}).get("dtype", "float16")
+            precision = name_to_dtype(dtype, device)
+            try:
+                # Easier to ask forgiveness than permission
+                quant_handler = ao_quantizer_class_dict[quantizer](
+                    groupsize=q_kwargs["groupsize"], device=device, precision=precision
+                )
+            except TypeError as e:
+                if "unexpected keyword argument 'device'" in str(e):
+                    quant_handler = ao_quantizer_class_dict[quantizer](
+                        groupsize=q_kwargs["groupsize"], precision=precision
+                    )
+                elif "unexpected keyword argument 'precision'" in str(e):
+                    quant_handler = ao_quantizer_class_dict[quantizer](
+                        groupsize=q_kwargs["groupsize"], device=device
+                    )
+                else:
+                    raise e
+            model = quant_handler.quantize(model)
+        else:
+            model = quantizer_class_dict[quantizer](
+                model, device=device, tokenizer=tokenizer, **q_kwargs
+            ).quantized_model()


 #########################################################################
@@ -509,176 +532,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)


-#########################################################################
-##### weight only int4 per channel groupwise quantized code ######
-
-
-class WeightOnlyInt4QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device=None,
-        *,
-        tokenizer=None,
-        groupsize=128,
-        inner_k_tiles=8,
-        padding_allowed=True,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
-        assert groupsize in [32, 64, 128, 256]
-        assert inner_k_tiles in [2, 4, 8]
-
-    @torch.no_grad()
-    def quantize(self, module):
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, torch.nn.Linear):
-                assert not child.bias
-                out_features = child.out_features
-                in_features = child.in_features
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                weight = child.weight.data
-                if not WeightOnlyInt4Linear._check_k(
-                    k=in_features,
-                    groupsize=self.groupsize,
-                    inner_k_tiles=self.inner_k_tiles,
-                ):
-                    if self.padding_allowed:
-                        # print(
-                        #     f"warning: {name} is padded to satisfy in_features % 1024 == 0"
-                        # )
-                        padded_in_features = find_multiple(in_features, 1024)
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        print(
-                            f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
-                        )
-                        continue
-                weight_int4pack, scales_and_zeros = (
-                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
-                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
-                    )
-                )
-                weight_int4pack = weight_int4pack.to(device=self.device)
-                scales_and_zeros = scales_and_zeros.to(device=self.device)
-
-                setattr(
-                    module,
-                    name,
-                    WeightOnlyInt4Linear(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=self.device,
-                        groupsize=self.groupsize,
-                        inner_k_tiles=self.inner_k_tiles,
-                        weight=weight_int4pack,
-                        scales_and_zeros=scales_and_zeros,
-                    ),
-                )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
-#########################################################################
-##### weight only int4 per channel groupwise quantized code ######
-
-
-class Int8DynActInt4WeightQuantizer(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device=None,
-        dtype=None,
-        *,
-        tokenizer=None,
-        groupsize=128,
-        padding_allowed=True,
-        precision=torch.float32,
-        scales_precision=torch.float32,
-    ):
-        if dtype is None:
-            dtype = torch.float32
-
-        self.model_ = model
-        self.device = device
-        self.dtype = dtype
-
-        self.groupsize = groupsize
-        self.padding_allowed = padding_allowed
-        self.precision = precision
-        self.scales_precision = scales_precision
-        assert groupsize in [32, 64, 128, 256]
-
-    @torch.no_grad()
-    def quantize(self, module):
-        from torchao.quantization.quant_primitives import (
-            group_quantize_tensor_symmetric,
-        )
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, torch.nn.Linear):
-                out_features = child.out_features
-                in_features = child.in_features
-                weight = child.weight.data
-                assert not child.bias
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                # if self.padding_allowed:
-                #     padding_multiple = max(self.groupsize, 1024)
-                padding_multiple = self.groupsize
-                padded_in_features = find_multiple(in_features, padding_multiple)
-                weight = F.pad(weight, pad=(0, padded_in_features - in_features))
-                (
-                    weight_int8,
-                    scales,
-                    zeros,
-                ) = group_quantize_tensor_symmetric(
-                    weight.float(),
-                    4,  # n_bit
-                    self.groupsize,
-                    self.scales_precision,
-                )
-
-                setattr(
-                    module,
-                    name,
-                    LinearAct8Int4DQ(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=self.device,
-                        dtype=self.dtype,
-                        groupsize=self.groupsize,
-                        weight=weight_int8.to(device=self.device),
-                        scales=scales.to(device=self.device),
-                    ),
-                )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 ##########################################################################
 ### quantization dictionary ###

@@ -688,8 +541,11 @@ def quantized_model(self) -> nn.Module:
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
-    "linear:int4": WeightOnlyInt4QuantHandler,
-    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
 }
+
+ao_quantizer_class_dict = {
+    "linear:int4": Int4WeightOnlyQuantizer,
+    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
+}
```
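For reference, a minimal sketch of the new torchao path for a `linear:a8w4dq` entry. The toy model below is an assumption for illustration only; the exact constructor keywords accepted (`device`, `precision`) vary across torchao versions, which is why `quantize_model` above falls back on `TypeError`:

```python
import torch
from torch import nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

# Toy stand-in for the torchchat model; any nn.Module with bias-free nn.Linear layers works.
model = nn.Sequential(nn.Linear(512, 512, bias=False), nn.Linear(512, 256, bias=False))

# groupsize comes from the --quantize JSON; precision from its "precision" entry.
# quantize.py also tries passing device= and drops it on TypeError for torchao builds
# whose constructor does not accept that keyword.
quant_handler = Int8DynActInt4WeightQuantizer(groupsize=256, precision=torch.float16)
model = quant_handler.quantize(model)  # swaps eligible nn.Linear modules for quantized ones
```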

requirements.txt

Lines changed: 0 additions & 3 deletions
```diff
@@ -1,8 +1,5 @@
 # Requires python >=3.10

-# PyTorch ecosystem
-torchao==0.1
-
 # Hugging Face download
 huggingface_hub

```
