
Commit a6e51cf: Enable GPTQ in executorch
Summary: Previously we only added the GPTQ code path but didn't test it. This PR also tests GPTQ locally to make sure we can produce a model using GPTQ from the torchao repo. Currently blocked on XNNPACK lowering.

Test Plan: python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w-gptq -X

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 39c93aa commit a6e51cf

2 files changed: 8 additions, 11 deletions

examples/models/llama2/export_llama_lib.py (6 additions, 9 deletions)
@@ -37,7 +37,6 @@
 
 from .quantize import (
     EmbeddingOnlyInt8QuantHandler,
-    Int8DynActInt4WeightGPTQQuantHandler,
     Int8DynActInt4WeightQuantHandler,
     WeightOnlyInt8QuantHandler,
 )
@@ -183,7 +182,7 @@ def quantize(
     groupsize: int = 128,
     # following arguments only used for GPTQ
     calibration_tasks: Optional[list] = None,
-    calibration_limit: int = 1000,
+    calibration_limit: int = 5,
     calibration_seq_length: int = 100,
     pad_calibration_inputs: bool = False,
     percdamp: float = 0.01,
@@ -206,7 +205,7 @@ def quantize(
         checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
 
     if calibration_tasks is None:
-        calibration_tasks = ["hellaswag"]
+        calibration_tasks = ["wikitext"]
 
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
@@ -219,15 +218,14 @@ def quantize(
         print("quantized model:", model_int4)
         return model_int4
     elif qmode == "8da4w-gptq":
-        gptq_quant_handler = Int8DynActInt4WeightGPTQQuantHandler(
-            precision=torch_dtype,
-        )
+        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
+
         tokenizer_path = checkpoint_path.parent / "tokenizer.model"
         assert tokenizer_path.is_file(), tokenizer_path
         tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
             model_file=str(tokenizer_path)
         )
-        model_updated_state_dict = gptq_quant_handler.create_quantized_state_dict(
+        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
             tokenizer,
             blocksize,
             percdamp,
@@ -237,8 +235,7 @@ def quantize(
             calibration_seq_length,
             pad_calibration_inputs,
         )
-        model = gptq_quant_handler.convert_for_runtime(model)
-        model.load_state_dict(model_updated_state_dict)
+        model = gptq_quantizer.quantize(model)
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")

examples/models/llama2/quantize.py (2 additions, 2 deletions)
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 from .ops.quantized_ops import *  # noqa
 
+# TODO: move to correct place
 from torchao.quantization.quant_primitives import (
     get_group_qparams_symmetric,
     group_quantize_tensor_symmetric,
@@ -652,7 +653,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.scales,
             self.zeros,
             self.out_features,
-            self.groupsize,
+            self.group_size,
             self.precision,
         )
 
@@ -737,7 +738,6 @@ class GPTQQuantHandler(QuantHandler):
     """
 
     def __init__(self):
-        assert self.mod is not None
         assert self.get_qparams_func is not None
         assert self.quantize_func is not None
         assert self.dequantize_func is not None
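
The last hunk drops assert self.mod is not None from GPTQQuantHandler.__init__, presumably because the model is now handed to the quantizer's quantize() call rather than stored on the handler; the constructor keeps only the checks on the functional hooks a subclass has to supply. Below is a rough, illustrative skeleton of that contract, limited to the hooks visible in this hunk; the class name, hook roles, and comments are assumptions, not code from the repo.

# Illustrative only: shows what the remaining asserts in GPTQQuantHandler.__init__
# expect a subclass to provide. Assumes GPTQQuantHandler from
# examples/models/llama2/quantize.py is in scope.
class ExampleGPTQHandler(GPTQQuantHandler):  # hypothetical subclass
    def __init__(self):
        # Set the hooks before delegating to the base __init__, which asserts
        # that each of them is not None. The bodies here are placeholders.
        self.get_qparams_func = ...   # weight tensor -> per-group qparams (assumed role)
        self.quantize_func = ...      # (weight, qparams) -> quantized weight (assumed role)
        self.dequantize_func = ...    # inverse mapping used during GPTQ (assumed role)
        super().__init__()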
