Skip to content

Commit 2d8fa1f

Browse files
jerryzh168 and facebook-github-bot
authored and committed
Add some gptq related args to quantize function (#2577)
Summary: Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #2577 Test Plan: Manually verified locally that the args passed through to quantize function `python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w-gptq -X -d fp32 -G 2568 --calibration_tasks wikitext fads --calibration_seq_length 1288 --calibration_limit 5123` Pull Request resolved: #2577 python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w-gptq -X -d fp32 -G 2568 --calibration_tasks wikitext fads --calibration_seq_length 1288 --calibration_limit 5123 Reviewed By: Jack-Khuu Differential Revision: D55250463 Pulled By: jerryzh168 fbshipit-source-id: bdf1299952c1f1010a39849bcf70f398bddfce06
1 parent ba920e4 commit 2d8fa1f

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def quantize(
204204
activation_dtype: Optional[DType],
205205
checkpoint_path: Optional[Path] = None,
206206
# following arguments only available when setting int4 quantization.
207-
groupsize: int = 128,
207+
group_size: int = 128,
208208
# following arguments only used for GPTQ
209209
calibration_tasks: Optional[list] = None,
210210
calibration_limit: int = 5,
@@ -255,7 +255,7 @@ def quantize(
255255
tokenizer,
256256
blocksize,
257257
percdamp,
258-
groupsize,
258+
group_size,
259259
calibration_tasks,
260260
calibration_limit,
261261
calibration_seq_length,
@@ -320,6 +320,25 @@ def build_args_parser() -> argparse.ArgumentParser:
320320
default=f"{ckpt_dir}/params/demo_rand_params.pth",
321321
help="checkpoint path",
322322
)
323+
parser.add_argument(
324+
"--calibration_tasks",
325+
nargs="+",
326+
type=str,
327+
default=[],
328+
help="Tasks for GPTQ calibration",
329+
)
330+
parser.add_argument(
331+
"--calibration_limit",
332+
type=int,
333+
default=5,
334+
help="number of samples used for calibration",
335+
)
336+
parser.add_argument(
337+
"--calibration_seq_length",
338+
type=int,
339+
default=2048,
340+
help="Sequence length for GPTQ calibration",
341+
)
323342
parser.add_argument(
324343
"-t",
325344
"--tokenizer_path",
@@ -370,7 +389,9 @@ def build_args_parser() -> argparse.ArgumentParser:
370389
default=None,
371390
help="Use cProfile to profile model export. Results saved to profile_path as a html file.",
372391
)
373-
parser.add_argument("-G", "--groupsize", default=None, help="specify the groupsize")
392+
parser.add_argument(
393+
"-G", "--group_size", default=None, help="group_size for weight quantization"
394+
)
374395

375396
parser.add_argument(
376397
"-d",
@@ -487,6 +508,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
487508
tokenizer_path=(
488509
Path(path) if (path := args.tokenizer_path) is not None else None
489510
),
511+
group_size=args.group_size,
512+
calibration_tasks=args.calibration_tasks,
513+
calibration_limit=args.calibration_limit,
514+
calibration_seq_length=args.calibration_seq_length,
490515
)
491516
)
492517

0 commit comments

Comments
 (0)