@@ -168,6 +168,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "-gqa" || arg == "--gqa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_gqa = std::stoi(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -485,6 +491,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  -f FNAME, --file FNAME\n");
     fprintf(stdout, "                        prompt file to start generation.\n");
     fprintf(stdout, "  -n N, --n-predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
+    fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stdout, "  -gqa N, --gqa N       grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
     fprintf(stdout, "  --top-k N             top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     fprintf(stdout, "  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stdout, "  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -505,15 +514,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --cfg-negative-prompt PROMPT\n");
     fprintf(stdout, "                        negative prompt to use for guidance. (default: empty)\n");
     fprintf(stdout, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    fprintf(stdout, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stdout, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stdout, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stdout, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
     fprintf(stdout, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
-    fprintf(stdout, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, "  --perplexity          compute perplexity over each ctx window of the prompt\n");
     fprintf(stdout, "  --perplexity-lines    compute perplexity over each line of the prompt\n");
     fprintf(stdout, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
@@ -580,6 +587,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
 
     lparams.n_ctx        = params.n_ctx;
     lparams.n_batch      = params.n_batch;
+    lparams.n_gqa        = params.n_gqa;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu     = params.main_gpu;
     lparams.tensor_split = params.tensor_split;
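
For reference, below is a minimal, self-contained sketch of the argument-parsing pattern this commit follows: consume the flag, then consume its value, and mark the parameter invalid if the value is missing. The `gpt_params` struct here is a stand-in (its real definition and the default value of `n_gqa` are not visible in this diff); only the `-gqa`/`--gqa` handling mirrors the change above.

// Hypothetical, self-contained sketch; not part of the commit.
#include <cstdio>
#include <string>

// Stand-in for the real gpt_params: only the field this commit adds is
// shown, and its default here (1) is an assumption, not taken from the diff.
struct gpt_params {
    int n_gqa = 1; // grouped-query attention factor
};

// Mirrors the parse-loop pattern in gpt_params_parse: advance past the flag
// to its value, failing if the flag is the last argument.
static bool parse_gqa(int argc, char ** argv, gpt_params & params) {
    bool invalid_param = false;
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "-gqa" || arg == "--gqa") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params.n_gqa = std::stoi(argv[i]); // throws on non-numeric input
        }
    }
    return !invalid_param;
}

int main(int argc, char ** argv) {
    gpt_params params;
    if (!parse_gqa(argc, argv, params)) {
        fprintf(stderr, "error: missing value for -gqa\n");
        return 1;
    }
    // e.g. `./sketch --gqa 8` (the value suggested for LLaMAv2 70B) prints 8
    fprintf(stdout, "n_gqa = %d\n", params.n_gqa);
    return 0;
}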