Skip to content

Commit 9b6c568

Browse files
committed
Enable GPTQ in executorch
Summary: Previously we just added the code but didn't test it, this PR also tests gptq locally to make sure we can produce a model using gptq from torchao repo Currently blocked on xnnpack lowering Test Plan: python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w-gptq -X Reviewers: Subscribers: Tasks: Tags:
1 parent 3507412 commit 9b6c568

File tree

3 files changed

+17
-519
lines changed

3 files changed

+17
-519
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535

3636
from .quantize import (
3737
EmbeddingOnlyInt8QuantHandler,
38-
Int8DynActInt4WeightGPTQQuantHandler,
3938
Int8DynActInt4WeightQuantHandler,
4039
WeightOnlyInt8QuantHandler,
4140
)
@@ -181,7 +180,7 @@ def quantize(
181180
groupsize: int = 128,
182181
# following arguments only used for GPTQ
183182
calibration_tasks: Optional[list] = None,
184-
calibration_limit: int = 1000,
183+
calibration_limit: int = 5,
185184
calibration_seq_length: int = 100,
186185
pad_calibration_inputs: bool = False,
187186
percdamp: float = 0.01,
@@ -204,7 +203,7 @@ def quantize(
204203
checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
205204

206205
if calibration_tasks is None:
207-
calibration_tasks = ["hellaswag"]
206+
calibration_tasks = ["wikitext"]
208207

209208
if qmode == "int8":
210209
# Add quantization mode options here: group size, bit width, etc.
@@ -217,15 +216,14 @@ def quantize(
217216
print("quantized model:", model_int4)
218217
return model_int4
219218
elif qmode == "8da4w-gptq":
220-
gptq_quant_handler = Int8DynActInt4WeightGPTQQuantHandler(
221-
precision=torch_dtype,
222-
)
219+
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
220+
223221
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
224222
assert tokenizer_path.is_file(), tokenizer_path
225223
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
226224
model_file=str(tokenizer_path)
227225
)
228-
model_updated_state_dict = gptq_quant_handler.create_quantized_state_dict(
226+
gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
229227
tokenizer,
230228
blocksize,
231229
percdamp,
@@ -235,8 +233,7 @@ def quantize(
235233
calibration_seq_length,
236234
pad_calibration_inputs,
237235
)
238-
model = gptq_quant_handler.convert_for_runtime(model)
239-
model.load_state_dict(model_updated_state_dict)
236+
model = gptq_quantizer.quantize(model)
240237
return model
241238
else:
242239
raise Exception(f"Unrecognized quantize mode: {qmode}")

examples/models/llama2/llama_transformer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def forward(
211211
bsz, seqlen, _ = x.shape
212212

213213
# QKV
214+
# TODO: re-enable
214215
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
216+
# xq, xk, xv = x, x, x
215217
# We need view_copy elimination
216218
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
217219
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
@@ -461,9 +463,9 @@ def forward(
461463
freqs_cos = self.freqs_cos[sp : sp + seqlen]
462464
freqs_sin = self.freqs_sin[sp : sp + seqlen]
463465
else:
464-
assert (
465-
start_pos is None and cache_k is None and cache_v is None
466-
), "Caches and start_pos are unused when use_kv_cache is False"
466+
# assert (
467+
# start_pos is None and cache_k is None and cache_v is None
468+
# ), "Caches and start_pos are unused when use_kv_cache is False"
467469
freqs_cos = self.freqs_cos[:seqlen]
468470
freqs_sin = self.freqs_sin[:seqlen]
469471

0 commit comments

Comments
 (0)