Commit d13164c

Jack-Khuu authored and facebook-github-bot committed
Force callers to manually provide GPTQ args (#2795)
Summary:
Default values for quant are convenient, but have caused confusion. Since there are no active users of GPTQ, we want to preempt potential ambiguity by making the following fields required and manually specified:

* Group Size
* Calibration Limit
* Calibration Sequence Length

Note: Group Size's default value is untouched, since it is utilized by int4 quant and has active callers.

Reviewed By: mergennachin

Differential Revision: D55605859
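Under the new contract, a GPTQ caller must spell out each of these fields. A minimal usage sketch, assuming `quantize` is imported from `examples/models/llama2/export_llama_lib.py` and that a loaded `model` is the leading positional argument of the full (unshown) signature:

```python
# Hypothetical usage sketch, not part of this commit: under the new
# signature, a GPTQ caller must pass every calibration field explicitly.
quantized = quantize(
    model,  # assumed leading argument; not visible in this diff
    qmode="8da4w-gptq",
    activation_dtype=None,
    group_size=128,
    calibration_limit=100,        # previously defaulted to 100
    calibration_seq_length=2048,  # previously defaulted to 2048
)
```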
1 parent: d612c23 · commit: d13164c

File tree

1 file changed: +22, -9 lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 22 additions & 9 deletions
```diff
@@ -15,7 +15,7 @@
 
 from functools import partial
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import pkg_resources
 import torch
```
```diff
@@ -214,12 +214,12 @@ def quantize(
     qmode: str,
     activation_dtype: Optional[DType],
     checkpoint_path: Optional[Path] = None,
-    # following arguments only available when setting int4 quantization.
-    group_size: int = 128,
-    # following arguments only used for GPTQ
+    # following arguments only available when setting int4 or gptq quantization.
+    group_size: Optional[int] = 128,
+    # following arguments are only used for GPTQ
     calibration_tasks: Optional[list] = None,
-    calibration_limit: int = 100,
-    calibration_seq_length: int = 2048,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: bool = False,
     percdamp: float = 0.01,
     blocksize: int = 128,
```
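The new signature uses the common Python idiom of a `None` default as a "not provided" sentinel, with enforcement deferred to runtime. A self-contained sketch of that idiom (hypothetical names, not code from this commit):

```python
from typing import Optional


def resolve_calibration_limit(calibration_limit: Optional[int] = None) -> int:
    # None means the caller never chose a value; fail fast rather than
    # silently substituting a default that may not match their intent.
    if calibration_limit is None:
        raise ValueError("calibration_limit must be specified")
    return calibration_limit
```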
```diff
@@ -245,13 +245,13 @@ def quantize(
     # if checkpoint_path is None:
     #     checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
 
-    if calibration_tasks is None:
-        calibration_tasks = ["wikitext"]
-
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode == "8da4w":
+        # Check for required args
+        if group_size is None:
+            raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
         model = Int8DynActInt4WeightQuantizer(
```
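With the new guard, an 8da4w call that leaves the group size unset fails immediately. A hedged sketch of the failure path, assuming the same `model` and surrounding signature as above:

```python
# Hypothetical call, not from this commit: group_size=None now raises
#   Exception: For 8da4w quantization, group size must be specified.
# instead of quantizing with an implicit group size.
quantize(model, qmode="8da4w", activation_dtype=None, group_size=None)
```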
```diff
@@ -261,6 +261,19 @@ def quantize(
         print("quantized model:", model)
         return model
     elif qmode == "8da4w-gptq":
+        # Check for required args
+        required_args: Optional[Any] = [
+            group_size,
+            calibration_limit,
+            calibration_seq_length,
+        ]
+        if any(arg is None for arg in required_args):
+            raise Exception(
+                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
+            )
+        if calibration_tasks is None:
+            calibration_tasks = ["wikitext"]
+
         from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
 
         if tokenizer_path is None:
```
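The GPTQ branch checks all three required values at once via `any(arg is None for arg in ...)`. The same idiom in isolation, as a small runnable sketch with hypothetical names:

```python
def check_gptq_args(group_size, calibration_limit, calibration_seq_length):
    # Gather the caller-supplied values and reject the call if any of
    # them was left as None.
    required = [group_size, calibration_limit, calibration_seq_length]
    if any(arg is None for arg in required):
        raise Exception(
            "group size, calibration limit and calibration sequence "
            "length must be specified"
        )


check_gptq_args(128, 100, 2048)      # passes
# check_gptq_args(128, None, 2048)   # would raise
```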
