Commit 2a3e6ab

Jack-Khuu authored and facebook-github-bot committed
Defer resolution of the default values of arguments used by quantize (#2738)
Summary: Quantize() (specifically GPTQ) is the sole user of many of these params, but their default values are introduced early and in multiple places, which is bug-prone and confusing.

* For example, the default value of calibration_tasks was previously [], which is not something `Int8DynActInt4WeightGPTQQuantizer` handles gracefully.

This diff defers default-value resolution to quantize(), since that is the direct call that uses them.

Reviewed By: jerryzh168

Differential Revision: D55458866
1 parent 01e259c commit 2a3e6ab
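
To make the pattern concrete, here is a minimal, self-contained sketch of the approach the summary describes: forward only the flags the user actually set, so everything left as None falls through to quantize()'s own signature defaults. The Namespace values and the qmode string below are illustrative, not taken from the repo.

```python
# Hypothetical sketch: defer default resolution to the callee by dropping
# None-valued CLI args before the call. Demo values are illustrative.
from argparse import Namespace
from typing import Optional


def quantize(
    qmode: str,
    group_size: int = 128,
    calibration_tasks: Optional[list] = None,
    calibration_limit: int = 100,
    calibration_seq_length: int = 2048,
) -> None:
    # Defaults live here, next to the code that consumes them.
    print(qmode, group_size, calibration_tasks, calibration_limit, calibration_seq_length)


def collect_set_args(args: Namespace) -> dict:
    # None means "flag not passed": omit it so quantize() uses its own default.
    names = [
        "group_size",
        "calibration_tasks",
        "calibration_limit",
        "calibration_seq_length",
    ]
    arg_dict = vars(args)
    return {name: val for name in names if (val := arg_dict.get(name)) is not None}


# All argparse defaults are now None; here the user only set --calibration_limit.
args = Namespace(
    group_size=None,
    calibration_tasks=None,
    calibration_limit=8,
    calibration_seq_length=None,
)
quantize(qmode="8da4w-gptq", **collect_set_args(args))
# -> 8da4w-gptq 128 None 8 2048
```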


examples/models/llama2/export_llama_lib.py

Lines changed: 21 additions & 9 deletions
```diff
@@ -218,8 +218,8 @@ def quantize(
     group_size: int = 128,
     # following arguments only used for GPTQ
     calibration_tasks: Optional[list] = None,
-    calibration_limit: int = 5,
-    calibration_seq_length: int = 100,
+    calibration_limit: int = 100,
+    calibration_seq_length: int = 2048,
     pad_calibration_inputs: bool = False,
     percdamp: float = 0.01,
     blocksize: int = 128,
@@ -342,19 +342,19 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--calibration_tasks",
         nargs="+",
         type=str,
-        default=[],
+        default=None,
         help="Tasks for GPTQ calibration",
     )
     parser.add_argument(
         "--calibration_limit",
         type=int,
-        default=5,
+        default=None,
         help="number of samples used for calibration",
     )
     parser.add_argument(
         "--calibration_seq_length",
         type=int,
-        default=2048,
+        default=None,
         help="Sequence length for GPTQ calibration",
     )
     parser.add_argument(
@@ -531,9 +531,25 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     transforms = []
     if args.quantization_mode:
         modelname = f"{modelname}_q"
+
+        # If these optional args are None, don't provide them to quantize()
+        quant_args_str = [
+            "group_size",
+            "calibration_tasks",
+            "calibration_limit",
+            "calibration_seq_length",
+        ]
+        arg_dict = vars(args)
+        quant_args = {
+            param: val
+            for param in quant_args_str
+            if (val := arg_dict.get(param)) is not None
+        }
+
         transforms.append(
             partial(
                 quantize,
+                **quant_args,
                 qmode=args.quantization_mode,
                 activation_dtype=dtype_override,
                 checkpoint_path=(
@@ -542,10 +558,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
             tokenizer_path=(
                 Path(path) if (path := args.tokenizer_path) is not None else None
             ),
-            group_size=args.group_size,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
         )
     )
```
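
A note on the mechanism in the last two hunks: functools.partial freezes the filtered keyword arguments when the transforms list is built, and the transform runs against the model later. A rough sketch of that flow, where the model object, the loop, and the demo values are stand-ins rather than repo code:

```python
# Sketch of partial() freezing filtered kwargs for a deferred transform call.
from functools import partial


def quantize(model, qmode, group_size=128, calibration_limit=100):
    # group_size falls back to its signature default because the user never
    # set it; calibration_limit was explicitly provided on the CLI.
    print(f"qmode={qmode} group_size={group_size} calibration_limit={calibration_limit}")
    return model


quant_args = {"calibration_limit": 8}  # survivors of the None-filtering above

transforms = [partial(quantize, **quant_args, qmode="8da4w-gptq")]

model = object()  # stand-in for the eager model being exported
for transform in transforms:
    model = transform(model)
# -> qmode=8da4w-gptq group_size=128 calibration_limit=8
```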
551563
