Commit 5a5935e
Enable GPTQ in executorch
Summary: Previously we only added the code but didn't test it. This PR also tests GPTQ locally to make sure we can produce a model using GPTQ from the torchao repo. Currently blocked on XNNPACK lowering.

Test Plan: python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w-gptq -X

Reviewers:

Subscribers:

Tasks:

Tags:
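The change swaps the local `Int8DynActInt4WeightGPTQQuantHandler` for torchao's `Int8DynActInt4WeightGPTQQuantizer`, which folds calibration, quantized-state-dict creation, and module swapping into a single `quantize()` call. Below is a minimal sketch of the new path using the defaults visible in the diff; note the blocksize value and the argument order past `percdamp` are assumptions, since the diff elides the middle of the constructor call:

```python
# Sketch of the new 8da4w-gptq path in export_llama_lib.py (not verbatim code).
# The blocksize value and the argument order past `percdamp` are assumptions.
from sentencepiece import SentencePieceProcessor
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

model = ...  # eager Llama nn.Module, e.g. loaded from stories110M.pt
tokenizer = SentencePieceProcessor(model_file="tokenizer.model")

gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
    tokenizer,
    128,           # blocksize (assumed)
    0.01,          # percdamp, default in quantize()
    128,           # groupsize, default in quantize()
    ["wikitext"],  # calibration_tasks, new default in this commit
    5,             # calibration_limit, lowered from 1000 in this commit
    100,           # calibration_seq_length
    False,         # pad_calibration_inputs
)
model = gptq_quantizer.quantize(model)  # returns the GPTQ-quantized module
```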
1 parent: 3507412

File tree

2 files changed: +13 −517 lines

examples/models/llama2/export_llama_lib.py

Lines changed: 6 additions & 9 deletions
```diff
@@ -35,7 +35,6 @@
 
 from .quantize import (
     EmbeddingOnlyInt8QuantHandler,
-    Int8DynActInt4WeightGPTQQuantHandler,
     Int8DynActInt4WeightQuantHandler,
     WeightOnlyInt8QuantHandler,
 )
@@ -181,7 +180,7 @@ def quantize(
     groupsize: int = 128,
     # following arguments only used for GPTQ
     calibration_tasks: Optional[list] = None,
-    calibration_limit: int = 1000,
+    calibration_limit: int = 5,
     calibration_seq_length: int = 100,
     pad_calibration_inputs: bool = False,
     percdamp: float = 0.01,
@@ -204,7 +203,7 @@ def quantize(
         checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
 
     if calibration_tasks is None:
-        calibration_tasks = ["hellaswag"]
+        calibration_tasks = ["wikitext"]
 
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
@@ -217,15 +216,14 @@ def quantize(
         print("quantized model:", model_int4)
         return model_int4
     elif qmode == "8da4w-gptq":
-        gptq_quant_handler = Int8DynActInt4WeightGPTQQuantHandler(
-            precision=torch_dtype,
-        )
+        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
+
         tokenizer_path = checkpoint_path.parent / "tokenizer.model"
         assert tokenizer_path.is_file(), tokenizer_path
         tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
             model_file=str(tokenizer_path)
         )
-        model_updated_state_dict = gptq_quant_handler.create_quantized_state_dict(
+        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
             tokenizer,
             blocksize,
             percdamp,
@@ -235,8 +233,7 @@ def quantize(
             calibration_seq_length,
             pad_calibration_inputs,
         )
-        model = gptq_quant_handler.convert_for_runtime(model)
-        model.load_state_dict(model_updated_state_dict)
+        model = gptq_quantizer.quantize(model)
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
```
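Net effect at the call site: the removed local-handler flow needed three steps (compute a quantized state dict, swap modules for runtime, reload the weights), while the torchao quantizer does all of that inside one call. A before/after comparison reconstructed from the hunks above, using the diff's own variable names and eliding the constructor arguments as the diff does:

```python
# Before this commit: local handler, three steps.
gptq_quant_handler = Int8DynActInt4WeightGPTQQuantHandler(precision=torch_dtype)
model_updated_state_dict = gptq_quant_handler.create_quantized_state_dict(
    tokenizer, blocksize, percdamp, ...,  # calibration args as in the diff
)
model = gptq_quant_handler.convert_for_runtime(model)
model.load_state_dict(model_updated_state_dict)

# After this commit: torchao quantizer, one step.
gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
    tokenizer, blocksize, percdamp, ...,  # same calibration args
)
model = gptq_quantizer.quantize(model)
```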
