Commit e1914fa

Authored by larryliu0820 and Jack Zhang

Update quantize.py to use torchao Quantizers (#882)
Summary: Remove duplicate code for Int4WeightOnlyQuantizer and Int8DynActInt4WeightQuantizer and use the torchao API instead.

Test Plan:

```
python torchchat.py generate llama2 --quantize '{"linear:int4": {"groupsize": 256}, "precision": {"dtype":"float16"}, "executor":{"accelerator":"cpu"}}' --prompt "Once upon a time," --max-new-tokens 256
python torchchat.py generate llama2 --quantize '{"linear:a8w4dq": {"groupsize": 256}, "precision": {"dtype":"float16"}, "executor":{"accelerator":"cpu"}}' --prompt "Once upon a time," --max-new-tokens 256
```

Squashed commit messages:

* Fix import
* Install torchao from gh
* Explain import
* Fix dependencies
* Test ao PR #479
* Update torchao hash
* Update torchao pin
* Fix scheduler bf16/fp16 mix error
* Incorporate torchao changes
* Update hash
* Fix GPU CI job
* More fixes
* Fix executorch CI job
* Use quant api for int4 weight only quantization
* Fix
* Fix again
* Fix 3
* Fix 4
* Try something
* Debug
* Only migrate 8a4w

Co-authored-by: Jack Zhang <[email protected]>
1 parent e1fb003 · commit e1914fa
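For readers skimming the diffs below, here is a minimal sketch (not part of this commit) of the torchao Quantizer flow that quantize.py now delegates to. The toy model is a stand-in for torchchat's Transformer, and on some torchao revisions the constructor may not accept `device` or `precision`, which is exactly what the new try/except in quantization/quantize.py guards against.

```python
# Minimal sketch of the torchao Quantizer API adopted by this commit.
# The tiny model here is illustrative; torchchat passes its real model.
import torch
import torch.nn as nn

from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

model = nn.Sequential(nn.Linear(256, 256, bias=False))  # stand-in model
quantizer = Int8DynActInt4WeightQuantizer(groupsize=256, precision=torch.float16)
model = quantizer.quantize(model)  # replaces eligible nn.Linear modules
```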

File tree: 5 files changed (+50, -99 lines)

.github/workflows/pull.yml

Lines changed: 4 additions & 3 deletions

```diff
@@ -449,12 +449,13 @@ jobs:
           fi
       - name: Install requirements
         run: |
-          echo "Intalling pip3 packages"
-          ./install_requirements.sh
-
+          # Have to install ET first because deps of Torchchat might not be the same.
           export TORCHCHAT_ROOT=$PWD
           ./scripts/install_et.sh
 
+          echo "Intalling pip3 packages"
+          ./install_requirements.sh
+
           pip3 list
           python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
           python3 -c 'import torchvision;print(f"torchvision: {torchvision.__version__, torchvision.version.git_version}")'
```

build/gguf_loader.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -14,8 +14,8 @@
 import torch
 
 from gguf import GGUFValueType
-from quantization.quantize import pack_scales_and_zeros, WeightOnlyInt4Linear
-
+from quantization.qops import LinearInt4 as WeightOnlyInt4Linear
+from quantization.quantize import pack_scales_and_zeros
 from build.gguf_util import Q4_0, to_float
 from build.model import ModelArgs, Transformer
```

install_requirements.sh

Lines changed: 4 additions & 0 deletions

```diff
@@ -70,6 +70,10 @@ REQUIREMENTS_TO_INSTALL=(
 # versions on the provided URL if they aren't available on the default URL.
 $PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" \
     "${REQUIREMENTS_TO_INSTALL[@]}"
+
+# For torchao need to install from github since nightly build doesn't have macos build.
+# TODO: Remove this and install nightly build, once it supports macos
+$PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@ca1b98db60543a1669a32e842762fc008c178376
 if [[ -x "$(command -v nvidia-smi)" ]]; then
     $PYTHON_EXECUTABLE scripts/patch_triton.py
 fi
```
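A quick, hypothetical smoke test (not part of the commit) for the pinned install: the import below is the one quantization/quantize.py depends on after this change, so it fails fast if the git install above did not succeed.

```python
# Hypothetical check that the pinned torchao build is importable;
# raises ImportError if the pip git install above failed.
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

print("torchao quantizer import OK:", Int8DynActInt4WeightQuantizer.__name__)
```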

quantization/quantize.py

Lines changed: 40 additions & 91 deletions

```diff
@@ -24,12 +24,15 @@
 )
 
 from quantization.qops import (
-    LinearAct8Int4DQ,
     LinearInt4 as WeightOnlyInt4Linear,
     LinearInt8 as WeightOnlyInt8Linear,
     QuantizedEmbedding,
 )
 
+# AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
+from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
+from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
 
 #########################################################################
 ###              torchchat quantization API                           ###
@@ -50,12 +53,40 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
     quantize_options = json.loads(quantize_options)
 
     for quantizer, q_kwargs in quantize_options.items():
-        if quantizer not in quantizer_class_dict:
+        if (
+            quantizer not in quantizer_class_dict
+            and quantizer not in ao_quantizer_class_dict
+        ):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
+        if quantizer in ao_quantizer_class_dict:
+            # Use dtype precision specified in user config, else fallback on global precision.
+            if "precision" in quantize_options:
+                dtype = quantize_options["precision"].get("dtype", str(get_precision()))
+                precision = name_to_dtype(dtype, device)
+            else:
+                precision = get_precision()
 
-        model = quantizer_class_dict[quantizer](
-            model, device=device, tokenizer=tokenizer, **q_kwargs
-        ).quantized_model()
+            try:
+                # Easier to ask forgiveness than permission
+                quant_handler = ao_quantizer_class_dict[quantizer](
+                    groupsize=q_kwargs["groupsize"], device=device, precision=precision
+                )
+            except TypeError as e:
+                if "unexpected keyword argument 'device'" in str(e):
+                    quant_handler = ao_quantizer_class_dict[quantizer](
+                        groupsize=q_kwargs["groupsize"], precision=precision
+                    )
+                elif "unexpected keyword argument 'precision'" in str(e):
+                    quant_handler = ao_quantizer_class_dict[quantizer](
+                        groupsize=q_kwargs["groupsize"], device=device
+                    )
+                else:
+                    raise e
+            model = quant_handler.quantize(model)
+        else:
+            model = quantizer_class_dict[quantizer](
+                model, device=device, tokenizer=tokenizer, **q_kwargs
+            ).quantized_model()
 
 
 
@@ -594,91 +625,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)
 
 
-#########################################################################
-#####     weight only int4 per channel groupwise quantized code    ######
-
-
-class Int8DynActInt4WeightQuantizer(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device=None,
-        dtype=None,
-        *,
-        tokenizer=None,
-        groupsize=128,
-        padding_allowed=True,
-        precision=torch.float32,
-        scales_precision=torch.float32,
-    ):
-        if dtype is None:
-            dtype = torch.float32
-
-        self.model_ = model
-        self.device = device
-        self.dtype = dtype
-
-        self.groupsize = groupsize
-        self.padding_allowed = padding_allowed
-        self.precision = precision
-        self.scales_precision = scales_precision
-        assert groupsize in [32, 64, 128, 256]
-
-    @torch.no_grad()
-    def quantize(self, module):
-        from torchao.quantization.quant_primitives import (
-            group_quantize_tensor_symmetric,
-        )
-
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, torch.nn.Linear):
-                out_features = child.out_features
-                in_features = child.in_features
-                weight = child.weight.data
-                assert not child.bias
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                # if self.padding_allowed:
-                #     padding_multiple=max(self.groupsize, 1024)
-                padding_multiple = self.groupsize
-                padded_in_features = find_multiple(in_features, padding_multiple)
-                weight = F.pad(weight, pad=(0, padded_in_features - in_features))
-                (
-                    weight_int8,
-                    scales,
-                    zeros,
-                ) = group_quantize_tensor_symmetric(
-                    weight.float(),
-                    4,  # n_bit
-                    self.groupsize,
-                    self.scales_precision,
-                )
-
-                setattr(
-                    module,
-                    name,
-                    LinearAct8Int4DQ(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=self.device,
-                        dtype=self.dtype,
-                        groupsize=self.groupsize,
-                        weight=weight_int8.to(device=self.device),
-                        scales=scales.to(device=self.device),
-                    ),
-                )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 ##########################################################################
 ###             quantization dictionary                                ###
 
@@ -689,7 +635,10 @@ def quantized_model(self) -> nn.Module:
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
     "linear:int4": WeightOnlyInt4QuantHandler,
-    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
 }
+
+ao_quantizer_class_dict = {
+    "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
+}
```
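To make the new dispatch concrete, here is a usage sketch (assumes a built torchchat `model` and `tokenizer` are in scope; not code from the commit). With a config like the test plan's, `linear:a8w4dq` is looked up in `ao_quantizer_class_dict` and handled by the torchao quantizer, while `precision` and `executor` still go through `quantizer_class_dict`; the hunk does not show a return statement, so the sketch relies on the handlers replacing modules in place.

```python
# Sketch of driving the updated quantize_model with a test-plan-style config.
# Assumed in scope: quantize_model (quantization/quantize.py), model, tokenizer.
quantize_options = (
    '{"linear:a8w4dq": {"groupsize": 256},'
    ' "precision": {"dtype": "float16"},'
    ' "executor": {"accelerator": "cpu"}}'
)

# "linear:a8w4dq" -> ao_quantizer_class_dict -> Int8DynActInt4WeightQuantizer.quantize()
# "precision"/"executor" -> quantizer_class_dict -> .quantized_model()
quantize_model(model, device="cpu", quantize_options=quantize_options, tokenizer=tokenizer)
```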

requirements.txt

Lines changed: 0 additions & 3 deletions

```diff
@@ -1,8 +1,5 @@
 # Requires python >=3.10
 
-# PyTorch ecosystem
-torchao==0.1
-
 # Hugging Face download
 huggingface_hub
```
