Skip to content

Commit d5512b7

Browse files
authored
server: add rms_norm_eps parameter (#2380)
1 parent c798308 commit d5512b7

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

examples/server/server.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
609609
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
610610
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
611611
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
612+
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
612613
fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
613614
fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
614615
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
@@ -734,6 +735,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
734735
}
735736
params.n_gqa = std::stoi(argv[i]);
736737
}
738+
else if (arg == "-eps" || arg == "--rms-norm-eps") {
739+
if (++i >= argc)
740+
{
741+
invalid_param = true;
742+
break;
743+
}
744+
params.rms_norm_eps = std::stof(argv[i]);
745+
}
737746
else if (arg == "--rope-freq-base")
738747
{
739748
if (++i >= argc)

0 commit comments

Comments
 (0)