Skip to content

Commit 7503eeb

Browse files
committed
Enable GPTQ in executorch
Summary: Previously we just added the code but didn't test it, this PR also tests gptq locally to make sure we can produce a model using gptq from torchao repo Currently blocked on xnnpack lowering Test Plan: python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w-gptq -X Reviewers: Subscribers: Tasks: Tags:
1 parent 39c93aa commit 7503eeb

File tree

3 files changed

+11
-14
lines changed

3 files changed

+11
-14
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737

3838
from .quantize import (
3939
EmbeddingOnlyInt8QuantHandler,
40-
Int8DynActInt4WeightGPTQQuantHandler,
4140
Int8DynActInt4WeightQuantHandler,
4241
WeightOnlyInt8QuantHandler,
4342
)
@@ -183,7 +182,7 @@ def quantize(
183182
groupsize: int = 128,
184183
# following arguments only used for GPTQ
185184
calibration_tasks: Optional[list] = None,
186-
calibration_limit: int = 1000,
185+
calibration_limit: int = 5,
187186
calibration_seq_length: int = 100,
188187
pad_calibration_inputs: bool = False,
189188
percdamp: float = 0.01,
@@ -206,7 +205,7 @@ def quantize(
206205
checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
207206

208207
if calibration_tasks is None:
209-
calibration_tasks = ["hellaswag"]
208+
calibration_tasks = ["wikitext"]
210209

211210
if qmode == "int8":
212211
# Add quantization mode options here: group size, bit width, etc.
@@ -219,15 +218,14 @@ def quantize(
219218
print("quantized model:", model_int4)
220219
return model_int4
221220
elif qmode == "8da4w-gptq":
222-
gptq_quant_handler = Int8DynActInt4WeightGPTQQuantHandler(
223-
precision=torch_dtype,
224-
)
221+
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
222+
225223
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
226224
assert tokenizer_path.is_file(), tokenizer_path
227225
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
228226
model_file=str(tokenizer_path)
229227
)
230-
model_updated_state_dict = gptq_quant_handler.create_quantized_state_dict(
228+
gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
231229
tokenizer,
232230
blocksize,
233231
percdamp,
@@ -237,8 +235,7 @@ def quantize(
237235
calibration_seq_length,
238236
pad_calibration_inputs,
239237
)
240-
model = gptq_quant_handler.convert_for_runtime(model)
241-
model.load_state_dict(model_updated_state_dict)
238+
model = gptq_quantizer.quantize(model)
242239
return model
243240
else:
244241
raise Exception(f"Unrecognized quantize mode: {qmode}")

examples/models/llama2/llama_transformer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,9 @@ def forward(
464464
freqs_cos = self.freqs_cos[sp : sp + seqlen]
465465
freqs_sin = self.freqs_sin[sp : sp + seqlen]
466466
else:
467-
assert (
468-
start_pos is None and cache_k is None and cache_v is None
469-
), "Caches and start_pos are unused when use_kv_cache is False"
467+
# assert (
468+
# start_pos is None and cache_k is None and cache_v is None
469+
# ), "Caches and start_pos are unused when use_kv_cache is False"
470470
freqs_cos = self.freqs_cos[:seqlen]
471471
freqs_sin = self.freqs_sin[:seqlen]
472472

examples/models/llama2/quantize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch.nn.functional as F
1414
from .ops.quantized_ops import * # noqa
1515

16+
# TODO: move to correct place
1617
from torchao.quantization.quant_primitives import (
1718
get_group_qparams_symmetric,
1819
group_quantize_tensor_symmetric,
@@ -652,7 +653,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
652653
self.scales,
653654
self.zeros,
654655
self.out_features,
655-
self.groupsize,
656+
self.group_size,
656657
self.precision,
657658
)
658659

@@ -737,7 +738,6 @@ class GPTQQuantHandler(QuantHandler):
737738
"""
738739

739740
def __init__(self):
740-
assert self.mod is not None
741741
assert self.get_qparams_func is not None
742742
assert self.quantize_func is not None
743743
assert self.dequantize_func is not None

0 commit comments

Comments
 (0)