Update quantize.py to use torchao Quantizers #882

Merged
merged 22 commits into from
Jul 17, 2024
7 changes: 4 additions & 3 deletions .github/workflows/pull.yml
@@ -449,12 +449,13 @@ jobs:
fi
- name: Install requirements
run: |
echo "Installing pip3 packages"
./install_requirements.sh

# Have to install ET first because deps of Torchchat might not be the same.
export TORCHCHAT_ROOT=$PWD
./scripts/install_et.sh

echo "Installing pip3 packages"
./install_requirements.sh

pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
python3 -c 'import torchvision;print(f"torchvision: {torchvision.__version__, torchvision.version.git_version}")'
4 changes: 2 additions & 2 deletions build/gguf_loader.py
@@ -14,8 +14,8 @@
import torch

from gguf import GGUFValueType
from quantization.quantize import pack_scales_and_zeros, WeightOnlyInt4Linear

from quantization.qops import LinearInt4 as WeightOnlyInt4Linear
from quantization.quantize import pack_scales_and_zeros
from build.gguf_util import Q4_0, to_float
from build.model import ModelArgs, Transformer

4 changes: 4 additions & 0 deletions install_requirements.sh
@@ -70,6 +70,10 @@ REQUIREMENTS_TO_INSTALL=(
# versions on the provided URL if they aren't available on the default URL.
$PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" \
"${REQUIREMENTS_TO_INSTALL[@]}"

# torchao needs to be installed from GitHub since the nightly build doesn't include a macOS build.
# TODO: Remove this and install the nightly build once it supports macOS
$PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@ca1b98db60543a1669a32e842762fc008c178376
if [[ -x "$(command -v nvidia-smi)" ]]; then
$PYTHON_EXECUTABLE scripts/patch_triton.py
fi
131 changes: 40 additions & 91 deletions quantization/quantize.py
@@ -24,12 +24,15 @@
)

from quantization.qops import (
LinearAct8Int4DQ,
LinearInt4 as WeightOnlyInt4Linear,
LinearInt8 as WeightOnlyInt8Linear,
QuantizedEmbedding,
)

# AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer


#########################################################################
### torchchat quantization API ###
@@ -50,12 +53,40 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
quantize_options = json.loads(quantize_options)

for quantizer, q_kwargs in quantize_options.items():
if quantizer not in quantizer_class_dict:
if (
quantizer not in quantizer_class_dict
and quantizer not in ao_quantizer_class_dict
):
raise RuntimeError(f"unknown quantizer {quantizer} specified")
if quantizer in ao_quantizer_class_dict:
# Use dtype precision specified in user config, else fallback on global precision.
if "precision" in quantize_options:
dtype = quantize_options["precision"].get("dtype", str(get_precision()))
precision = name_to_dtype(dtype, device)
else:
precision = get_precision()

model = quantizer_class_dict[quantizer](
model, device=device, tokenizer=tokenizer, **q_kwargs
).quantized_model()
try:
# Easier to ask forgiveness than permission
quant_handler = ao_quantizer_class_dict[quantizer](
groupsize=q_kwargs["groupsize"], device=device, precision=precision
)
except TypeError as e:
if "unexpected keyword argument 'device'" in str(e):
quant_handler = ao_quantizer_class_dict[quantizer](
groupsize=q_kwargs["groupsize"], precision=precision
)
elif "unexpected keyword argument 'precision'" in str(e):
quant_handler = ao_quantizer_class_dict[quantizer](
groupsize=q_kwargs["groupsize"], device=device
)
else:
raise e
model = quant_handler.quantize(model)
else:
model = quantizer_class_dict[quantizer](
model, device=device, tokenizer=tokenizer, **q_kwargs
).quantized_model()


#########################################################################
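The branch above constructs the torchao quantizer with an EAFP probe: it first passes the full keyword set, then retries with fewer keywords when the installed torchao revision rejects one. A minimal standalone sketch of the same pattern (make_quantizer is a hypothetical helper, not part of this PR):

def make_quantizer(cls, groupsize, device, precision):
    # Try the fullest constructor first; drop whichever keyword the
    # installed quantizer class reports as unexpected.
    try:
        return cls(groupsize=groupsize, device=device, precision=precision)
    except TypeError as e:
        if "unexpected keyword argument 'device'" in str(e):
            return cls(groupsize=groupsize, precision=precision)
        if "unexpected keyword argument 'precision'" in str(e):
            return cls(groupsize=groupsize, device=device)
        raise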
@@ -594,91 +625,6 @@ def quantized_model(self) -> nn.Module:
return self.quantize(self.model_)


#########################################################################
##### weight only int4 per channel groupwise quantized code ######


class Int8DynActInt4WeightQuantizer(QuantHandler):
def __init__(
self,
model: nn.Module,
device=None,
dtype=None,
*,
tokenizer=None,
groupsize=128,
padding_allowed=True,
precision=torch.float32,
scales_precision=torch.float32,
):
if dtype is None:
dtype = torch.float32

self.model_ = model
self.device = device
self.dtype = dtype

self.groupsize = groupsize
self.padding_allowed = padding_allowed
self.precision = precision
self.scales_precision = scales_precision
assert groupsize in [32, 64, 128, 256]

@torch.no_grad()
def quantize(self, module):
from torchao.quantization.quant_primitives import (
group_quantize_tensor_symmetric,
)

for name, child in module.named_children():
# print(f"name: {name}")
if isinstance(child, torch.nn.Linear):
out_features = child.out_features
in_features = child.in_features
weight = child.weight.data
assert not child.bias
assert out_features % 8 == 0, "require out_features % 8 == 0"
# print(f"linear: {fqn}, in={in_features}, out={out_features}")

# if self.padding_allowed:
# padding_multiple=max(self.groupsize, 1024)
padding_multiple = self.groupsize
padded_in_features = find_multiple(in_features, padding_multiple)
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
(
weight_int8,
scales,
zeros,
) = group_quantize_tensor_symmetric(
weight.float(),
4, # n_bit
self.groupsize,
self.scales_precision,
)

setattr(
module,
name,
LinearAct8Int4DQ(
child.in_features,
child.out_features,
bias=False,
device=self.device,
dtype=self.dtype,
groupsize=self.groupsize,
weight=weight_int8.to(device=self.device),
scales=scales.to(device=self.device),
),
)
else:
self.quantize(child)

return module

def quantized_model(self) -> nn.Module:
return self.quantize(self.model_)
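The in-tree Int8DynActInt4WeightQuantizer deleted above is superseded by the torchao implementation imported at the top of the file, and is driven the same way the dispatch code drives it. A rough usage sketch, assuming the pinned torchao revision (the toy model and groupsize are illustrative):

import torch.nn as nn
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# Toy model with a bias-free Linear, as required by the a8w4dq path.
model = nn.Sequential(nn.Linear(256, 256, bias=False))
quant_handler = Int8DynActInt4WeightQuantizer(groupsize=128)
model = quant_handler.quantize(model)  # swaps eligible nn.Linear modules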


##########################################################################
### quantization dictionary ###

@@ -689,7 +635,10 @@
"embedding": EmbeddingOnlyQuantHandler,
"linear:int8": WeightOnlyInt8QuantHandler,
"linear:int4": WeightOnlyInt4QuantHandler,
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
"precision": PrecisionHandler,
"executor": ExecutorHandler,
}

ao_quantizer_class_dict = {
"linear:a8w4dq": Int8DynActInt4WeightQuantizer,
}
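With the two dictionaries split, quantize_model routes "linear:a8w4dq" through the torchao path and every other key through the in-tree handlers. A hypothetical call with a JSON options string (the groupsize value is illustrative):

import json

# "linear:a8w4dq" is looked up in ao_quantizer_class_dict above.
options = json.dumps({"linear:a8w4dq": {"groupsize": 256}})
quantize_model(model, device="cpu", quantize_options=options)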
3 changes: 0 additions & 3 deletions requirements.txt
@@ -1,8 +1,5 @@
# Requires python >=3.10

# PyTorch ecosystem
torchao==0.1

# Hugging Face download
huggingface_hub
