@@ -204,7 +204,7 @@ def quantize(
204
204
activation_dtype : Optional [DType ],
205
205
checkpoint_path : Optional [Path ] = None ,
206
206
# following arguments only available when setting int4 quantization.
207
- groupsize : int = 128 ,
207
+ group_size : int = 128 ,
208
208
# following arguments only used for GPTQ
209
209
calibration_tasks : Optional [list ] = None ,
210
210
calibration_limit : int = 5 ,
@@ -255,7 +255,7 @@ def quantize(
255
255
tokenizer ,
256
256
blocksize ,
257
257
percdamp ,
258
- groupsize ,
258
+ group_size ,
259
259
calibration_tasks ,
260
260
calibration_limit ,
261
261
calibration_seq_length ,
@@ -320,6 +320,25 @@ def build_args_parser() -> argparse.ArgumentParser:
320
320
default = f"{ ckpt_dir } /params/demo_rand_params.pth" ,
321
321
help = "checkpoint path" ,
322
322
)
323
+ parser .add_argument (
324
+ "--calibration_tasks" ,
325
+ nargs = "+" ,
326
+ type = str ,
327
+ default = [],
328
+ help = "Tasks for GPTQ calibration" ,
329
+ )
330
+ parser .add_argument (
331
+ "--calibration_limit" ,
332
+ type = int ,
333
+ default = 5 ,
334
+ help = "number of samples used for calibration" ,
335
+ )
336
+ parser .add_argument (
337
+ "--calibration_seq_length" ,
338
+ type = int ,
339
+ default = 2048 ,
340
+ help = "Sequence length for GPTQ calibration" ,
341
+ )
323
342
parser .add_argument (
324
343
"-t" ,
325
344
"--tokenizer_path" ,
@@ -370,7 +389,9 @@ def build_args_parser() -> argparse.ArgumentParser:
370
389
default = None ,
371
390
help = "Use cProfile to profile model export. Results saved to profile_path as a html file." ,
372
391
)
373
- parser .add_argument ("-G" , "--groupsize" , default = None , help = "specify the groupsize" )
392
+ parser .add_argument (
393
+ "-G" , "--group_size" , default = None , help = "group_size for weight quantization"
394
+ )
374
395
375
396
parser .add_argument (
376
397
"-d" ,
@@ -487,6 +508,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
487
508
tokenizer_path = (
488
509
Path (path ) if (path := args .tokenizer_path ) is not None else None
489
510
),
511
+ group_size = args .group_size ,
512
+ calibration_tasks = args .calibration_tasks ,
513
+ calibration_limit = args .calibration_limit ,
514
+ calibration_seq_length = args .calibration_seq_length ,
490
515
)
491
516
)
492
517
0 commit comments