Commit 8209b5d

revert llama_eval, create main example
1 parent 471e469 commit 8209b5d

File tree

3 files changed: +15 −20 lines changed

examples/main/main.cpp

Lines changed: 6 additions & 4 deletions
@@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
            fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);

            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
-            llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads, params.pp_threads);
+            llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
        }

        llama_print_timings(ctx);
@@ -406,7 +406,7 @@ int main(int argc, char ** argv) {
    // do one empty run to warm up the model
    {
        const std::vector<llama_token> tmp = { llama_token_bos(), };
-        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, params.pp_threads);
+        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
        llama_reset_timings(ctx);
    }
@@ -513,7 +513,8 @@ int main(int argc, char ** argv) {

                for (int i = 0; i < input_size; i += params.n_batch) {
                    int n_eval = std::min(input_size - i, params.n_batch);
-                    if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads, params.pp_threads)) {
+                    int eval_thr = n_eval > 1 ? params.pp_threads : params.n_threads;
+                    if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, eval_thr)) {
                        fprintf(stderr, "%s : failed to eval\n", __func__);
                        return 1;
                    }
@@ -527,7 +528,8 @@ int main(int argc, char ** argv) {
                if (n_eval > params.n_batch) {
                    n_eval = params.n_batch;
                }
-                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads, params.pp_threads)) {
+                int eval_thr = n_eval > 1 ? params.pp_threads : params.n_threads;
+                if (llama_eval(ctx, &embd[i], n_eval, n_past, eval_thr)) {
                    fprintf(stderr, "%s : failed to eval\n", __func__);
                    return 1;
                }
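
Taken together, the main.cpp hunks show the pattern this commit moves out of the library and into the example: llama_eval is back to a single thread-count argument, and the caller chooses pp_threads for multi-token (prompt) batches and n_threads for single-token (generation) steps. Below is a minimal, self-contained sketch of that caller-side pattern; eval_tokens is a hypothetical helper name, and n_threads, pp_threads, and n_batch are assumed to come from this example's command-line options.

#include <algorithm>   // std::min
#include <cstdio>      // fprintf
#include <vector>

#include "llama.h"

// Hypothetical helper mirroring the loop added to examples/main/main.cpp:
// batches longer than one token are treated as prompt processing and use
// pp_threads; single-token (generation) evals use n_threads.
static int eval_tokens(struct llama_context * ctx, const std::vector<llama_token> & tokens,
                       int & n_past, int n_batch, int n_threads, int pp_threads) {
    for (int i = 0; i < (int) tokens.size(); i += n_batch) {
        const int n_eval   = std::min((int) tokens.size() - i, n_batch);
        const int eval_thr = n_eval > 1 ? pp_threads : n_threads;
        if (llama_eval(ctx, tokens.data() + i, n_eval, n_past, eval_thr)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }
        n_past += n_eval;
    }
    return 0;
}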

llama.cpp

Lines changed: 7 additions & 12 deletions
@@ -1787,7 +1787,6 @@ static struct ggml_cgraph * llama_build_graph(
//   - n_tokens   number of tokens
//   - n_past:    the context size so far
//   - n_threads: number of threads to use for inference
-//   - pp_threads: number of threads to use for prompt processing
//
static bool llama_eval_internal(
         llama_context & lctx,
@@ -1796,7 +1795,6 @@ static bool llama_eval_internal(
                   int   n_tokens,
                   int   n_past,
                   int   n_threads,
-                   int   pp_threads,
           const char * cgraph_fname) {

    LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
@@ -1840,8 +1838,7 @@ static bool llama_eval_internal(

    // for big prompts, if BLAS is enabled, it is better to use only one thread
    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-    pp_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : pp_threads;
-    n_threads = N > 1 ? pp_threads : n_threads;
+    n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;

    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
@@ -3487,7 +3484,7 @@ struct llama_context * llama_new_context_with_model(
    if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
        // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
        const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos());
-        while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0, 0)) {};
+        while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {};
        llama_backend_free();
        exit(1);
    }
@@ -4179,9 +4176,8 @@ int llama_eval(
        const llama_token * tokens,
                       int   n_tokens,
                       int   n_past,
-                       int   n_threads,
-                       int   pp_threads) {
-    if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, pp_threads, nullptr)) {
+                       int   n_threads) {
+    if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
        LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
        return 1;
    }
@@ -4202,9 +4198,8 @@ int llama_eval_embd(
          const float * embd,
                  int   n_tokens,
                  int   n_past,
-                  int   n_threads,
-                  int   pp_threads) {
-    if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, pp_threads, nullptr)) {
+                  int   n_threads) {
+    if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) {
        LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
        return 1;
    }
@@ -4225,7 +4220,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {

    const std::vector<llama_token> tmp(n_batch, llama_token_bos());

-    if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, 1, fname)) {
+    if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) {
        LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
        return 1;
    }
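
The hunk around line 1840 is the core of the revert: instead of choosing between pp_threads and n_threads inside llama_eval_internal, the library keeps only the pre-existing BLAS override. A small sketch of the resulting rule is below; effective_threads is a hypothetical name, while ggml_cpu_has_blas and ggml_cpu_has_gpublas are the existing ggml feature checks used in the hunk.

#include "ggml.h"

// After the revert, the only thread adjustment left inside the library is the
// BLAS case: for batches of 32+ tokens with CPU BLAS (and no GPU BLAS), a
// single thread avoids spin-lock contention around the BLAS calls.
static int effective_threads(int n_tokens, int n_threads) {
    if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
        return 1;
    }
    return n_threads;
}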

llama.h

Lines changed: 2 additions & 4 deletions
@@ -308,17 +308,15 @@ extern "C" {
            const llama_token * tokens,
                          int   n_tokens,
                          int   n_past,
-                          int   n_threads,
-                          int   pp_threads);
+                          int   n_threads);

    // Same as llama_eval, but use float matrix input directly.
    LLAMA_API int llama_eval_embd(
             struct llama_context * ctx,
                      const float * embd,
                              int   n_tokens,
                              int   n_past,
-                             int   n_threads,
-                             int   pp_threads);
+                             int   n_threads);

    // Export a static computation graph for context of 511 and batch size of 1
    // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
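
With these header changes, the public declarations are back to a single n_threads parameter. A minimal sketch of calling the reverted API, mirroring the warm-up run from the example above: warm_up is a hypothetical wrapper, and the thread count of 4 is an arbitrary placeholder, while llama_token_bos() and llama_reset_timings() are used exactly as in the diff.

#include <vector>

#include "llama.h"

// Warm-up eval with a single BOS token using the reverted
// llama_eval(ctx, tokens, n_tokens, n_past, n_threads) signature.
static void warm_up(struct llama_context * ctx) {
    const std::vector<llama_token> tmp = { llama_token_bos() };
    llama_eval(ctx, tmp.data(), tmp.size(), /*n_past=*/0, /*n_threads=*/4);
    llama_reset_timings(ctx);
}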
