
add 16a4w_hqq quant mode #3752

Closed · wants to merge 5 commits

2 changes: 1 addition & 1 deletion examples/models/llama2/export_llama_lib.py
@@ -119,7 +119,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"--quantization_mode",
type=str,
default=None,
choices=["int8", "8da4w", "8da4w-gptq"],
choices=["int8", "8da4w", "8da4w-gptq", "16a4w-hqq"],
help="type of quantization",
)
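
For reference, a minimal sketch of how the new mode surfaces through the parser: the --quantization_mode argument below is copied from the diff above, while the standalone argparse wrapper around it is illustrative only.

import argparse

# Reproduce only the --quantization_mode argument from build_args_parser().
parser = argparse.ArgumentParser()
parser.add_argument(
    "--quantization_mode",
    type=str,
    default=None,
    choices=["int8", "8da4w", "8da4w-gptq", "16a4w-hqq"],
    help="type of quantization",
)

args = parser.parse_args(["--quantization_mode", "16a4w-hqq"])
assert args.quantization_mode == "16a4w-hqq"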

205 changes: 205 additions & 0 deletions examples/models/llama2/source_transformation/hqq_16a4w.py
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
[Contributor review comment] maybe you want to move this to torchao: https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq?

# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

########################## Run HQQ ###############################


def _replace_linear_4w_hqq(
module: torch.nn.Module,
quant_config,
compute_dtype,
del_orig=False,
):
"""
Recursively replace all Linear layers with HQQLinear modules holding the 4-bit quantized weights
"""
for name, child in module.named_children():
if isinstance(child, torch.nn.Linear):
new_linear = HQQLinear(
child,
quant_config,
compute_dtype=compute_dtype,
del_orig=True,
device="cpu",
)
setattr(module, name, new_linear)
else:
_replace_linear_4w_hqq(
child,
quant_config,
compute_dtype,
del_orig=False,
)


def replace_linear_4w_hqq(
module: torch.nn.Module,
quant_config: BaseQuantizeConfig,
compute_dtype,
del_orig=False,
):
"""
Replace all Linear layers with HQQLinear modules holding the 4-bit quantized weights
"""
_replace_linear_4w_hqq(
module,
quant_config,
compute_dtype,
del_orig=False,
)


def run_hqq_quantize(model: torch.nn.Module) -> None:
"""
Update the model in place with the HQQ-quantized weights
"""

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)

replace_linear_4w_hqq(model, quant_config=quant_config, compute_dtype=torch.float32)


########################## Use static quantization with HQQ Linear ###############################


def calibrate(
model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length
):
print("run calibration...")
eval_wrapper = EagerEvalWrapper(
model=model,
tokenizer=tokenizer,
max_seq_length=calibration_seq_length,
use_kv_cache=False,
)
eval_results = evaluate_model(
eval_wrapper,
tasks=calibration_tasks,
limit=calibration_limit,
)
for task, res in eval_results["results"].items():
print(f"Reference result with hqq model: {task}: {res}")


class LinearActivationFakeQuant(torch.nn.Module):
def __init__(self, linear):
super().__init__()
self.linear = linear
self.input_activation_fake_quant = torch.quantization.FakeQuantize(
observer=torch.quantization.MovingAverageMinMaxObserver,
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
quant_max=torch.iinfo(torch.uint16).max,
)
self.output_activation_fake_quant = torch.quantization.FakeQuantize(
observer=torch.quantization.MovingAverageMinMaxObserver,
dtype=torch.int32,
quant_min=torch.iinfo(torch.uint16).min,
quant_max=torch.iinfo(torch.uint16).max,
)

def forward(self, x):
x = self.input_activation_fake_quant(x)
return self.output_activation_fake_quant(self.linear(x))


def get_quant_params(activation_fake_quant):
quant_min = activation_fake_quant.quant_min
quant_max = activation_fake_quant.quant_max
qparams = activation_fake_quant.calculate_qparams()
scale = qparams[0]
zero_point = qparams[1]
return (quant_min, quant_max, scale, zero_point)


class LinearActivationQuant(torch.nn.Module):

def __init__(self, linear_fake_quant):
super().__init__()
self.linear_fake_quant = linear_fake_quant
(
self.input_quant_min,
self.input_quant_max,
self.input_scale,
self.input_zero_point,
) = get_quant_params(linear_fake_quant.input_activation_fake_quant)

(
self.output_quant_min,
self.output_quant_max,
self.output_scale,
self.output_zero_point,
) = get_quant_params(linear_fake_quant.output_activation_fake_quant)

def forward(self, x):
# Manually quantize the input tensor using observed min and max values
q_tensor = torch.round(x / self.input_scale + self.input_zero_point)
# Clamp to the range [quant_min, quant_max]
q_tensor = torch.clamp(q_tensor, self.input_quant_min, self.input_quant_max)
# Dequantize to the original scale
dequantized_tensor = (q_tensor - self.input_zero_point) * self.input_scale

linear_output = self.linear_fake_quant.linear(dequantized_tensor)

# Quantize the linear output tensor
q_linear_output = torch.round(
linear_output / self.output_scale + self.output_zero_point
)
q_linear_output = torch.clamp(
q_linear_output, self.output_quant_min, self.output_quant_max
)
# Dequantize the linear output tensor
dq_linear_output = (
q_linear_output - self.output_zero_point
) * self.output_scale

return dq_linear_output


def _replace_linear_quant_activation(module: torch.nn.Module, stage: str):
for name, child in module.named_children():
if stage == "convert":
if isinstance(child, LinearActivationFakeQuant):
new_linear = LinearActivationQuant(child)
setattr(module, name, new_linear)
else:
_replace_linear_quant_activation(child, stage)
elif stage == "prepare":
if isinstance(child, HQQLinear):
new_linear = LinearActivationFakeQuant(child)
setattr(module, name, new_linear)
else:
_replace_linear_quant_activation(child, stage)
else:
raise ValueError(f"Unsupported stage {stage}")


def replace_linear_quant_activation(module: torch.nn.Module, stage: str):
_replace_linear_quant_activation(
module,
stage,
)


def prepare(model):
"""
Prepare the model for quantization by manually inserting the observers
"""
replace_linear_quant_activation(model, "prepare")


def convert(model):
"""
Convert the observers into actual quant/dequant nodes. In this implementation we manually
call round, clamp, and the scale/zero-point arithmetic for quick prototyping
"""
replace_linear_quant_activation(model, "convert")
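
The manual round/clamp arithmetic in LinearActivationQuant.forward can be exercised on its own. Below is a small, self-contained sketch of that 16-bit activation round trip; the scale and zero point are made-up values for illustration, not the result of any calibration run (in the real flow they come from FakeQuantize.calculate_qparams()).

import torch

# 16-bit activation bounds, matching the uint16 range used by the fake-quant observers above.
quant_min = torch.iinfo(torch.uint16).min  # 0
quant_max = torch.iinfo(torch.uint16).max  # 65535

# Hypothetical qparams, chosen only for this example.
scale = 0.01
zero_point = 32768.0

x = torch.randn(4)

# Quantize: round(x / scale + zero_point), then clamp to [quant_min, quant_max].
q = torch.clamp(torch.round(x / scale + zero_point), quant_min, quant_max)

# Dequantize back to the original scale.
dq = (q - zero_point) * scale

# For in-range inputs the round-trip error is at most half a quantization step.
assert torch.all((x - dq).abs() <= scale / 2 + 1e-6)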
39 changes: 39 additions & 0 deletions examples/models/llama2/source_transformation/quantize.py
@@ -11,6 +11,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.examples.models.llama2.tokenizer.tokenizer import Tokenizer

from sentencepiece import SentencePieceProcessor

@@ -127,6 +128,44 @@ def quantize(
group_size,
)
model = gptq_quantizer.quantize(model, inputs)
return model
elif qmode == "16a4w-hqq":
try:
from executorch.examples.models.llama2.source_transformation import (
hqq_16a4w,
)
except ImportError:
print(
"Please follow instruction in https://github.com/mobiusml/hqq to install the latest version."
)
if calibration_tasks is None:
calibration_tasks = ["wikitext"]
if calibration_limit is None:
calibration_limit = 5
if calibration_seq_length is None:
calibration_seq_length = 128
if tokenizer_path is None:
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = Tokenizer(model_path=str(tokenizer_path)) # pyre-ignore[28]

# Step 1: Run HQQ weight quantization; each nn.Linear in the model is replaced with an HQQLinear
hqq_16a4w.run_hqq_quantize(model)

# Step 2: Insert activation observers
hqq_16a4w.prepare(model)
# Step 3: Calibrate the observers
hqq_16a4w.calibrate(
model=model,
tokenizer=tokenizer,
calibration_tasks=calibration_tasks,
calibration_limit=calibration_limit,
calibration_seq_length=calibration_seq_length,
)
# Step 4: Convert the observers into explicit quant/dequant arithmetic
hqq_16a4w.convert(model)

return model
else:
raise Exception(f"Unrecognized quantize mode: {qmode}")
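
Taken together, the new branch above boils down to a four-step flow. The sketch below mirrors that flow with the default calibration settings from the diff; quantize_16a4w_hqq is a hypothetical wrapper name, and the model and tokenizer_path arguments are placeholders.

import torch
from executorch.examples.models.llama2.source_transformation import hqq_16a4w
from executorch.examples.models.llama2.tokenizer.tokenizer import Tokenizer


def quantize_16a4w_hqq(model: torch.nn.Module, tokenizer_path: str) -> torch.nn.Module:
    tokenizer = Tokenizer(model_path=tokenizer_path)

    # Step 1: HQQ weight quantization; every nn.Linear becomes an HQQLinear.
    hqq_16a4w.run_hqq_quantize(model)
    # Step 2: insert fake-quant observers on the activations around each HQQLinear.
    hqq_16a4w.prepare(model)
    # Step 3: calibrate the observers on a small evaluation task.
    hqq_16a4w.calibrate(
        model=model,
        tokenizer=tokenizer,
        calibration_tasks=["wikitext"],
        calibration_limit=5,
        calibration_seq_length=128,
    )
    # Step 4: replace the observers with explicit quantize/dequantize arithmetic.
    hqq_16a4w.convert(model)
    return model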