
Commit 80ad94b

[quant] Use Int8DynActInt4WeightQuantizer in torchao
Summary: att

Test Plan:
python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w -X -d fp32

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 2fc1355
Pull Request resolved: #2551
1 parent: f9cad4e


examples/models/llama2/export_llama_lib.py

Lines changed: 6 additions & 8 deletions
@@ -212,11 +212,9 @@ def quantize(
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
-    elif qmode == "int4":
-        model_int4 = Int8DynActInt4WeightQuantHandler(
-            model,
-            precision=torch_dtype,
-        ).quantized_model()
+    elif qmode == "8da4w":
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+        model_int4 = Int8DynActInt4WeightQuantizer(precision=torch_dtype).quantize(model)
         print("quantized model:", model_int4)
         return model_int4
     elif qmode == "8da4w-gptq":
@@ -287,7 +285,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "int4", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq"],
         help="type of quantization",
     )

@@ -430,7 +428,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.dtype_override is not None:
         dtype_override = DType[args.dtype_override]
     else:
-        dtype_override = DType["fp16"] if args.quantization_mode == "int4" else None
+        dtype_override = DType["fp16"] if args.quantization_mode in ["8da4w", "8da4w-gptq"] else None

     # source transforms
     transforms = []
@@ -500,7 +498,7 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     if args.xnnpack:
         # Following changes due to.
         # 1. We need dynamically quantized partitioner for both pt2e_quantize options
-        # as well as "qmode int4" which is also dynamic quantizes linear layers.
+        # as well as "qmode 8da4w" which also dynamically quantizes linear layers.
         # 2. XNNPACK partitioner seems to result in seg fault for non dqlinear ops.
         partitioners[XnnpackDynamicallyQuantizedPartitioner.__name__] = (
             XnnpackDynamicallyQuantizedPartitioner()
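For reference, the new 8da4w path (int8 dynamic activations, int4 weights) can be exercised outside the export flow. A minimal sketch, assuming torchao is installed; the toy Linear module is a hypothetical stand-in for the Llama model, not part of this commit, and fp32 precision mirrors the `-d fp32` flag in the test plan:

    import torch
    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

    # Toy module standing in for the Llama model (hypothetical, for illustration).
    model = torch.nn.Sequential(torch.nn.Linear(256, 256))

    # Int8 dynamic activations + int4 weights, i.e. what `-qmode 8da4w` selects.
    quantizer = Int8DynActInt4WeightQuantizer(precision=torch.float32)
    model_int4 = quantizer.quantize(model)
    print("quantized model:", model_int4)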
