Commit 6bef9e7

andrewor14 authored and facebook-github-bot committed
Rename int4 to 8da4w in llama2 quantization (#2573)
Summary:
Pull Request resolved: #2573

"int4" has been confused with "int4 weight only" before, when in reality it is "int4 weights + int8 dynamic activations". Renaming it to "8da4w" will reduce confusion and make it more consistent with "8da4w-gptq".

#accept2land

Reviewed By: jerryzh168

Differential Revision: D55215146

fbshipit-source-id: 435c9b3e70e2546c8e0afc2df848546d7eb2d208
1 parent 60eb1bb commit 6bef9e7
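
For context on the naming: "8da4w" stands for 8-bit dynamically quantized activations plus 4-bit weights. A minimal illustrative sketch of that distinction in plain PyTorch follows; it is not the Int8DynActInt4WeightQuantHandler implementation, and the helper names, group size, and shapes are made up for illustration.

```python
import torch


def quantize_weight_int4_groupwise(w: torch.Tensor, group_size: int = 32):
    """Illustrative symmetric 4-bit per-group weight quantization."""
    out_features, in_features = w.shape
    wg = w.reshape(out_features, in_features // group_size, group_size)
    scale = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7  # int4 range [-8, 7]
    q = torch.clamp(torch.round(wg / scale), -8, 7)
    return q, scale


def quantize_activation_int8_dynamic(x: torch.Tensor):
    """Illustrative per-token dynamic 8-bit activation quantization."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127
    q = torch.clamp(torch.round(x / scale), -128, 127)
    return q, scale


w = torch.randn(16, 64)   # linear weight
x = torch.randn(2, 64)    # activations for two tokens

# "8da4w": quantize weights to int4 (per group) AND activations to int8 (per token).
wq, w_scale = quantize_weight_int4_groupwise(w)
xq, x_scale = quantize_activation_int8_dynamic(x)
w_deq = (wq * w_scale).reshape(16, 64)   # dequantize; stands in for a real int8/int4 kernel
y_8da4w = (xq * x_scale) @ w_deq.t()

# "int4 weight only" would skip the activation step and keep x in floating point.
y_weight_only = x @ w_deq.t()
```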

File tree

2 files changed: +12 -10 lines changed


examples/models/llama2/export_llama_lib.py

Lines changed: 10 additions & 8 deletions
@@ -222,7 +222,7 @@ def quantize(
     Quantizes a model by converting all weights to int8.
     Args:
         model: A model to quantize.
-        qmode: quantization mode, e.g. int8, int4
+        qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq
     Returns:
         A quantized model.
     """
@@ -240,13 +240,13 @@ def quantize(
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
-    elif qmode == "int4":
-        model_int4 = Int8DynActInt4WeightQuantHandler(
+    elif qmode == "8da4w":
+        model = Int8DynActInt4WeightQuantHandler(
             model,
             precision=torch_dtype,
         ).quantized_model()
-        print("quantized model:", model_int4)
-        return model_int4
+        print("quantized model:", model)
+        return model
     elif qmode == "8da4w-gptq":
         from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
 
@@ -315,7 +315,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "int4", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq"],
         help="type of quantization",
     )
 
@@ -472,8 +472,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     # dtype override
     if args.dtype_override is not None:
         dtype_override = DType[args.dtype_override]
+    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
+        dtype_override = DType["fp16"]
     else:
-        dtype_override = DType["fp16"] if args.quantization_mode == "int4" else None
+        dtype_override = None
 
     # source transforms
     transforms = []
@@ -547,7 +549,7 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     if args.xnnpack:
         # Following changes due to.
         # 1. We need dynamically quantized partitioner for both pt2e_quantize options
-        #    as well as "qmode int4" which is also dynamic quantizes linear layers.
+        #    as well as "qmode 8da4w" which is also dynamic quantizes linear layers.
         # 2. XNNPACK partitioner seems to result in seg fault for non dqlinear ops.
         partitioners[XnnpackDynamicallyQuantizedPartitioner.__name__] = (
             XnnpackDynamicallyQuantizedPartitioner()
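
One side effect worth noting in the dtype-override hunk above: before this change only "int4" implied an fp16 default, while "8da4w-gptq" fell through to no override; after it, both "8da4w" and "8da4w-gptq" default to fp16 unless --dtype_override is set explicitly. A minimal sketch of the resulting precedence, with DType replaced by a stand-in enum since the real class is not shown here:

```python
from enum import Enum


class DType(Enum):
    """Stand-in for the real DType used by the export flow."""
    fp16 = "fp16"
    fp32 = "fp32"


def resolve_dtype_override(dtype_override_arg, quantization_mode):
    """Mirror the precedence after this commit: an explicit --dtype_override
    wins, otherwise the 8-bit-dynamic-activation / 4-bit-weight modes default
    to fp16, otherwise there is no override."""
    if dtype_override_arg is not None:
        return DType[dtype_override_arg]
    elif quantization_mode in ["8da4w", "8da4w-gptq"]:
        return DType["fp16"]
    return None


assert resolve_dtype_override(None, "8da4w") is DType.fp16
assert resolve_dtype_override(None, "8da4w-gptq") is DType.fp16  # previously no override
assert resolve_dtype_override("fp32", "8da4w") is DType.fp32     # explicit flag wins
assert resolve_dtype_override(None, "int8") is None
```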

examples/models/llama2/model.py

Lines changed: 2 additions & 2 deletions
@@ -144,8 +144,8 @@ def __init__(self, **kwargs):
 
             simple_quantizer = WeightOnlyInt8QuantHandler(self.model_)
             self.model_ = simple_quantizer.convert_for_runtime()
-        elif "int4" in str(checkpoint_path):
-            print("Using int4 weight-only quantization!")
+        elif "8da4w" in str(checkpoint_path):
+            print("Using int4 weight and int8 dynamic activation quantization!")
             from .quantize import Int8DynActInt4WeightQuantHandler
 
             simple_quantizer = Int8DynActInt4WeightQuantHandler(self.model_)
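
Note that model.py selects the quantization handler from the checkpoint filename, so the renamed branch only fires when the exported checkpoint path contains "8da4w". A small sketch of that dispatch, where select_quant_handler and the example path are invented for illustration and the int8 condition is assumed to mirror the weight-only branch shown as context above:

```python
from pathlib import Path


def select_quant_handler(checkpoint_path: Path) -> str:
    """Pick the handler name from the checkpoint filename, mirroring the
    substring checks in model.py (the int8 condition is an assumption)."""
    if "int8" in str(checkpoint_path):
        return "WeightOnlyInt8QuantHandler"
    elif "8da4w" in str(checkpoint_path):
        return "Int8DynActInt4WeightQuantHandler"
    return "no quantization handler"


# A checkpoint exported with "--quantization_mode 8da4w" would need a matching name:
print(select_quant_handler(Path("llama2_8da4w_checkpoint.pth")))
# -> Int8DynActInt4WeightQuantHandler
```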
