Skip to content

Enable GPTQ in executorch #2425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

from .quantize import (
EmbeddingOnlyInt8QuantHandler,
Int8DynActInt4WeightGPTQQuantHandler,
Int8DynActInt4WeightQuantHandler,
WeightOnlyInt8QuantHandler,
)
Expand Down Expand Up @@ -183,7 +182,7 @@ def quantize(
groupsize: int = 128,
# following arguments only used for GPTQ
calibration_tasks: Optional[list] = None,
calibration_limit: int = 1000,
calibration_limit: int = 5,
calibration_seq_length: int = 100,
pad_calibration_inputs: bool = False,
percdamp: float = 0.01,
Expand All @@ -206,7 +205,7 @@ def quantize(
checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")

if calibration_tasks is None:
calibration_tasks = ["hellaswag"]
calibration_tasks = ["wikitext"]

if qmode == "int8":
# Add quantization mode options here: group size, bit width, etc.
Expand All @@ -219,15 +218,14 @@ def quantize(
print("quantized model:", model_int4)
return model_int4
elif qmode == "8da4w-gptq":
gptq_quant_handler = Int8DynActInt4WeightGPTQQuantHandler(
precision=torch_dtype,
)
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
)
model_updated_state_dict = gptq_quant_handler.create_quantized_state_dict(
gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
tokenizer,
blocksize,
percdamp,
Expand All @@ -237,8 +235,7 @@ def quantize(
calibration_seq_length,
pad_calibration_inputs,
)
model = gptq_quant_handler.convert_for_runtime(model)
model.load_state_dict(model_updated_state_dict)
model = gptq_quantizer.quantize(model)
return model
else:
raise Exception(f"Unrecognized quantize mode: {qmode}")
Expand Down
6 changes: 3 additions & 3 deletions examples/models/llama2/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,9 @@ def forward(
freqs_cos = self.freqs_cos[sp : sp + seqlen]
freqs_sin = self.freqs_sin[sp : sp + seqlen]
else:
assert (
start_pos is None and cache_k is None and cache_v is None
), "Caches and start_pos are unused when use_kv_cache is False"
# assert (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kimishpatel is this OK? I need to comment this out to make sure we can run the torch._dyncmo.export in GPTQ.

@HDCharles is going to work on a refactor of GPTQ to remove export and use tensor subclass instead, we can revert this change when that is implemented I think.

# start_pos is None and cache_k is None and cache_v is None
# ), "Caches and start_pos are unused when use_kv_cache is False"
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]

Expand Down
4 changes: 2 additions & 2 deletions examples/models/llama2/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch.nn.functional as F
from .ops.quantized_ops import * # noqa

# TODO: move to correct place
from torchao.quantization.quant_primitives import (
get_group_qparams_symmetric,
group_quantize_tensor_symmetric,
Expand Down Expand Up @@ -652,7 +653,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.scales,
self.zeros,
self.out_features,
self.groupsize,
self.group_size,
self.precision,
)

Expand Down Expand Up @@ -737,7 +738,6 @@ class GPTQQuantHandler(QuantHandler):
"""

def __init__(self):
assert self.mod is not None
assert self.get_qparams_func is not None
assert self.quantize_func is not None
assert self.dequantize_func is not None
Expand Down