Commit be26777

add pp_threads support to other files

1 parent d854348 commit be26777

3 files changed: +20 −7 lines

examples/embedding/embedding.cpp

Lines changed: 3 additions & 3 deletions

@@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
     // print system information
     {
         fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+        fprintf(stderr, "system_info: n_threads = %d / %d | pp_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), params.pp_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

     int n_past = 0;
@@ -74,7 +74,7 @@ int main(int argc, char ** argv) {

     if (params.embedding){
         if (embd_inp.size() > 0) {
-            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.n_threads)) {
+            if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads, params.pp_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return 1;
             }
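
For reference, the call pattern this commit rolls out across the examples: a minimal sketch, assuming this branch's extended six-argument llama_eval (upstream llama_eval takes a single thread count) and gpt_params from the examples' common.h. The eval_prompt wrapper below is hypothetical, for illustration only.

#include <cstdio>
#include <vector>

#include "common.h"  // gpt_params (assumed, as used by the examples)
#include "llama.h"

// Hypothetical helper: evaluate a tokenized prompt with separate thread counts.
// pp_threads drives the batched prompt pass, which can profitably use more
// threads than the per-token generation steps controlled by n_threads.
static bool eval_prompt(llama_context * ctx,
                        const std::vector<llama_token> & embd_inp,
                        const gpt_params & params) {
    int n_past = 0;
    if (llama_eval(ctx, embd_inp.data(), (int) embd_inp.size(), n_past,
                   params.n_threads, params.pp_threads)) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return false;
    }
    return true;
}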

examples/save-load-state/save-load-state.cpp

Lines changed: 4 additions & 3 deletions

@@ -10,6 +10,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.seed = 42;
     params.n_threads = 4;
+    params.pp_threads = 4;
     params.repeat_last_n = 64;
     params.prompt = "The quick brown fox";

@@ -56,7 +57,7 @@ int main(int argc, char ** argv) {
     }

     // evaluate prompt
-    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.n_threads);
+    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads, params.pp_threads);

     last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
     n_past += n_prompt_tokens;
@@ -93,7 +94,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);

         printf("%s", next_token_str);
-        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.n_threads)) {
+        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx);
            llama_free_model(model);
@@ -153,7 +154,7 @@ int main(int argc, char ** argv) {
         last_n_tokens_data.push_back(next_token);

         printf("%s", next_token_str);
-        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.n_threads)) {
+        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads, params.pp_threads)) {
            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);

examples/server/server.cpp

Lines changed: 13 additions & 1 deletion

@@ -382,7 +382,7 @@ struct llama_server_context
            {
                n_eval = params.n_batch;
            }
-           if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.n_threads))
+           if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads, params.pp_threads))
            {
                LOG_ERROR("failed to eval", {
                    {"n_eval", n_eval},
@@ -648,6 +648,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    fprintf(stdout, "  -h, --help                show this help message and exit\n");
    fprintf(stdout, "  -v, --verbose             verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
    fprintf(stdout, "  -t N, --threads N         number of threads to use during computation (default: %d)\n", params.n_threads);
+   fprintf(stdout, "  -ppt N, --pp-threads N\n");
+   fprintf(stdout, "                            number of threads to use during prompt processing (default: %d)\n", params.pp_threads);
    fprintf(stdout, "  -c N, --ctx-size N        size of the prompt context (default: %d)\n", params.n_ctx);
    fprintf(stdout, "  -gqa N, --gqa N           grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
    fprintf(stdout, "  -eps N, --rms-norm-eps N  rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
@@ -818,6 +820,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            }
            params.n_threads = std::stoi(argv[i]);
        }
+       else if (arg == "-ppt" || arg == "--pp-threads")
+       {
+           if (++i >= argc)
+           {
+               invalid_param = true;
+               break;
+           }
+           params.pp_threads = std::stoi(argv[i]);
+       }
        else if (arg == "-b" || arg == "--batch-size")
        {
            if (++i >= argc)
@@ -1178,6 +1189,7 @@ int main(int argc, char **argv)
                            {"commit", BUILD_COMMIT}});
    LOG_INFO("system info", {
                                {"n_threads", params.n_threads},
+                               {"pp_threads", params.pp_threads},
                                {"total_threads", std::thread::hardware_concurrency()},
                                {"system_info", llama_print_system_info()},
                            });
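
Usage note: with the parsing above, the server should accept the new flag alongside -t, e.g. ./server -m <model> -t 4 -ppt 8 (model path elided) to keep generation at 4 threads while letting the batched prompt pass use 8. When -ppt is omitted, params.pp_threads keeps its default, as reported by the help text.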
