@@ -608,6 +608,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
608
608
fprintf (stderr, " -v, --verbose verbose output (default: %s)\n " , server_verbose ? " enabled" : " disabled" );
609
609
fprintf (stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n " , params.n_threads );
610
610
fprintf (stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n " , params.n_ctx );
611
+ fprintf (stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n " , params.n_gqa );
611
612
fprintf (stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n " , params.rope_freq_base );
612
613
fprintf (stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n " , params.rope_freq_scale );
613
614
fprintf (stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n " , params.n_batch );
@@ -724,17 +725,28 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
724
725
}
725
726
params.n_ctx = std::stoi (argv[i]);
726
727
}
728
+ else if (arg == " -gqa" || arg == " --gqa" )
729
+ {
730
+ if (++i >= argc)
731
+ {
732
+ invalid_param = true ;
733
+ break ;
734
+ }
735
+ params.n_gqa = std::stoi (argv[i]);
736
+ }
727
737
else if (arg == " --rope-freq-base" )
728
738
{
729
- if (++i >= argc) {
739
+ if (++i >= argc)
740
+ {
730
741
invalid_param = true ;
731
742
break ;
732
743
}
733
744
params.rope_freq_base = std::stof (argv[i]);
734
745
}
735
746
else if (arg == " --rope-freq-scale" )
736
747
{
737
- if (++i >= argc) {
748
+ if (++i >= argc)
749
+ {
738
750
invalid_param = true ;
739
751
break ;
740
752
}
0 commit comments