
Commit 2ff6842

Enable GPTQ in executorch
Summary: Previously we just added the code but didn't test it; this PR also runs GPTQ locally to make sure we can produce a model using GPTQ from the torchao repo.

Test Plan: python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w-gptq

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 3507412 commit 2ff6842


3 files changed: +8 -605 lines


examples/models/llama2/export_llama_lib.py

Lines changed: 5 additions & 6 deletions
@@ -35,7 +35,6 @@
 from .quantize import (
     EmbeddingOnlyInt8QuantHandler,
-    Int8DynActInt4WeightGPTQQuantHandler,
     Int8DynActInt4WeightQuantHandler,
     WeightOnlyInt8QuantHandler,
 )
@@ -217,15 +216,15 @@ def quantize(
         print("quantized model:", model_int4)
         return model_int4
     elif qmode == "8da4w-gptq":
-        gptq_quant_handler = Int8DynActInt4WeightGPTQQuantHandler(
-            precision=torch_dtype,
+        from torchao.quantization.quant_api import (
+            Int8DynActInt4WeightGPTQQuantizer,
         )
         tokenizer_path = checkpoint_path.parent / "tokenizer.model"
         assert tokenizer_path.is_file(), tokenizer_path
         tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
             model_file=str(tokenizer_path)
         )
-        model_updated_state_dict = gptq_quant_handler.create_quantized_state_dict(
+        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
             tokenizer,
             blocksize,
             percdamp,
@@ -235,8 +234,7 @@ def quantize(
             calibration_seq_length,
             pad_calibration_inputs,
         )
-        model = gptq_quant_handler.convert_for_runtime(model)
-        model.load_state_dict(model_updated_state_dict)
+        model = gptq_quantizer.quantize(model)
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
@@ -442,6 +440,7 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     # export_to_edge
     pt2e_quant_params = _get_pt2e_quantization_params(args)
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
+    print("quantizers:", quantizers)

     # to_backend
     partitioners = {}
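
For reference, a minimal standalone sketch of the one-step torchao GPTQ flow that the quantize() change above switches to. This is not code from the commit: the helper name and the collapsed_args stand-in for the constructor arguments hidden in the collapsed diff lines (between percdamp and calibration_seq_length) are assumptions for illustration.

# Hypothetical sketch of the 8da4w-gptq branch above, not part of this commit.
from sentencepiece import SentencePieceProcessor
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

def quantize_with_torchao_gptq(
    model,
    tokenizer_path,
    blocksize,
    percdamp,
    collapsed_args,            # stands in for the arguments not shown in the diff
    calibration_seq_length,
    pad_calibration_inputs,
):
    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
    quantizer = Int8DynActInt4WeightGPTQQuantizer(
        tokenizer,
        blocksize,
        percdamp,
        *collapsed_args,       # splice in the collapsed arguments in order
        calibration_seq_length,
        pad_calibration_inputs,
    )
    # A single quantize() call replaces the old three-step sequence of
    # create_quantized_state_dict / convert_for_runtime / load_state_dict.
    return quantizer.quantize(model)

The end-to-end path is exercised by the Test Plan command above (-qmode 8da4w-gptq).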

examples/models/llama2/llama_transformer.py

Lines changed: 3 additions & 3 deletions
@@ -461,9 +461,9 @@ def forward(
             freqs_cos = self.freqs_cos[sp : sp + seqlen]
             freqs_sin = self.freqs_sin[sp : sp + seqlen]
         else:
-            assert (
-                start_pos is None and cache_k is None and cache_v is None
-            ), "Caches and start_pos are unused when use_kv_cache is False"
+            # assert (
+            #     start_pos is None and cache_k is None and cache_v is None
+            # ), "Caches and start_pos are unused when use_kv_cache is False"
             freqs_cos = self.freqs_cos[:seqlen]
             freqs_sin = self.freqs_sin[:seqlen]
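
The practical effect of commenting out that assert is that the non-KV-cache path now ignores any cache/start_pos arguments instead of raising. Below is a small illustrative sketch of the two rotary-frequency slicing behaviors the branch above selects between; the function name, table shapes, and values are placeholders, not code from this file.

import torch

def slice_rope_freqs(freqs_cos, freqs_sin, seqlen, start_pos=None, use_kv_cache=False):
    # Illustration only: mirrors the branch above on precomputed rotary tables.
    if use_kv_cache:
        sp = start_pos  # continue from the cached position
        return freqs_cos[sp : sp + seqlen], freqs_sin[sp : sp + seqlen]
    # With the assert commented out, stray cache/start_pos arguments are ignored here.
    return freqs_cos[:seqlen], freqs_sin[:seqlen]

# Example with placeholder tables of shape (max_seq_len, head_dim // 2).
freqs_cos = torch.randn(2048, 64)
freqs_sin = torch.randn(2048, 64)
cos, sin = slice_rope_freqs(freqs_cos, freqs_sin, seqlen=16, start_pos=5, use_kv_cache=True)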
