Skip to content

Commit 04a9ef8

Browse files
committed
llama : cont
ggml-ci
1 parent 5b7cc53 commit 04a9ef8

File tree

19 files changed

+126
-78
lines changed

19 files changed

+126
-78
lines changed

examples/batched-bench/batched-bench.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ int main(int argc, char ** argv) {
5757
return 1;
5858
}
5959

60+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
61+
6062
const int32_t n_kv_max = llama_n_ctx(ctx);
6163

6264
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
@@ -132,7 +134,7 @@ int main(int argc, char ** argv) {
132134

133135
const auto t_pp_start = ggml_time_us();
134136

135-
llama_kv_cache_clear(ctx);
137+
llama_kv_cache_clear(kv);
136138

137139
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
138140
LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +143,7 @@ int main(int argc, char ** argv) {
141143

142144
if (is_pp_shared) {
143145
for (int32_t i = 1; i < pl; ++i) {
144-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
146+
llama_kv_cache_seq_cp(kv, 0, i, -1, -1);
145147
}
146148
}
147149

examples/cvector-generator/cvector-generator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
342342
}
343343

344344
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
345-
llama_kv_cache_clear(ctx);
345+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
346+
llama_kv_cache_clear(kv);
346347
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
347348
fprintf(stderr, "%s : failed to eval\n", __func__);
348349
return false;

examples/gritlm/gritlm.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
1313
const llama_model * model = llama_get_model(ctx);
1414
const llama_vocab * vocab = llama_model_get_vocab(model);
1515

16+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
17+
1618
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
1719

1820
for (uint64_t i = 0; i < sentences.size(); i++) {
@@ -45,7 +47,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
4547
}
4648

4749
// clear previous kv_cache values (irrelevant for embeddings)
48-
llama_kv_cache_clear(ctx);
50+
llama_kv_cache_clear(kv);
4951
llama_set_embeddings(ctx, true);
5052
llama_set_causal_attn(ctx, false);
5153

@@ -100,9 +102,11 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
100102
const llama_model * model = llama_get_model(ctx);
101103
const llama_vocab * vocab = llama_model_get_vocab(model);
102104

105+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
106+
103107
llama_token eos_token = llama_vocab_eos(vocab);
104108

105-
llama_kv_cache_clear(ctx);
109+
llama_kv_cache_clear(kv);
106110
llama_set_embeddings(ctx, false);
107111
llama_set_causal_attn(ctx, true);
108112

examples/imatrix/imatrix.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
431431
const llama_model * model = llama_get_model(ctx);
432432
const llama_vocab * vocab = llama_model_get_vocab(model);
433433

434+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
435+
434436
const bool add_bos = llama_vocab_get_add_bos(vocab);
435437
const int n_ctx = llama_n_ctx(ctx);
436438

@@ -497,7 +499,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
497499
const auto t_start = std::chrono::high_resolution_clock::now();
498500

499501
// clear the KV cache
500-
llama_kv_cache_clear(ctx);
502+
llama_kv_cache_clear(kv);
501503

502504
llama_batch batch = llama_batch_init(n_batch, 0, 1);
503505

examples/infill/infill.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ int main(int argc, char ** argv) {
139139
return 1;
140140
}
141141

142+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
143+
142144
const llama_vocab * vocab = llama_model_get_vocab(model);
143145

144146
const int n_ctx_train = llama_model_n_ctx_train(model);
@@ -332,8 +334,8 @@ int main(int argc, char ** argv) {
332334
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
333335
n_past, n_left, n_ctx, params.n_keep, n_discard);
334336

335-
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
336-
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
337+
llama_kv_cache_seq_rm (kv, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
338+
llama_kv_cache_seq_add(kv, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
337339

338340
n_past -= n_discard;
339341

examples/llama-bench/llama-bench.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,9 +1546,11 @@ int main(int argc, char ** argv) {
15461546
return 1;
15471547
}
15481548

1549+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
1550+
15491551
test t(inst, lmodel, ctx);
15501552

1551-
llama_kv_cache_clear(ctx);
1553+
llama_kv_cache_clear(kv);
15521554

15531555
// cool off before the test
15541556
if (params.delay) {
@@ -1588,7 +1590,7 @@ int main(int argc, char ** argv) {
15881590
}
15891591

15901592
for (int i = 0; i < params.reps; i++) {
1591-
llama_kv_cache_clear(ctx);
1593+
llama_kv_cache_clear(kv);
15921594

15931595
uint64_t t_start = get_time_ns();
15941596

examples/lookahead/lookahead.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ int main(int argc, char ** argv) {
6060

6161
llama_model * model = llama_init.model.get();
6262
llama_context * ctx = llama_init.context.get();
63+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
6364

6465
const llama_vocab * vocab = llama_model_get_vocab(model);
6566

@@ -95,7 +96,7 @@ int main(int argc, char ** argv) {
9596
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
9697

9798
for (int s = 1; s < W + G + 1; ++s) {
98-
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
99+
llama_kv_cache_seq_cp(kv, 0, s, -1, -1);
99100
}
100101

101102
const auto t_enc_end = ggml_time_us();
@@ -437,17 +438,17 @@ int main(int argc, char ** argv) {
437438

438439
// KV cache management
439440
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
440-
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
441+
llama_kv_cache_seq_rm(kv, -1, n_past, -1);
441442

442443
if (seq_id_best != 0) {
443444
// if a verification token matched, we keep the best sequence and remove the rest
444445
// this leads to some KV cache fragmentation
445-
llama_kv_cache_seq_keep(ctx, seq_id_best);
446-
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
447-
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
446+
llama_kv_cache_seq_keep(kv, seq_id_best);
447+
llama_kv_cache_seq_cp (kv, seq_id_best, 0, -1, -1);
448+
llama_kv_cache_seq_rm (kv, seq_id_best, -1, -1);
448449

449450
for (int s = 1; s < W + G + 1; ++s) {
450-
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
451+
llama_kv_cache_seq_cp(kv, 0, s, -1, -1);
451452
}
452453
}
453454
}

examples/lookup/lookup.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ int main(int argc, char ** argv){
3535

3636
llama_model * model = llama_init.model.get();
3737
llama_context * ctx = llama_init.context.get();
38+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
3839

3940
const llama_vocab * vocab = llama_model_get_vocab(model);
4041

@@ -192,7 +193,7 @@ int main(int argc, char ** argv){
192193

193194
// KV cache management
194195
// clean the cache of draft tokens that weren't accepted
195-
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
196+
llama_kv_cache_seq_rm(kv, 0, n_past, -1);
196197

197198
common_batch_clear(batch_tgt);
198199
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

examples/main/main.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ int main(int argc, char ** argv) {
162162
return 1;
163163
}
164164

165+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
166+
165167
const llama_vocab * vocab = llama_model_get_vocab(model);
166168

167169
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@@ -306,7 +308,7 @@ int main(int argc, char ** argv) {
306308
}
307309

308310
// remove any "future" tokens that we might have inherited from the previous session
309-
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
311+
llama_kv_cache_seq_rm(kv, -1, n_matching_session_tokens, -1);
310312
}
311313

312314
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -543,8 +545,8 @@ int main(int argc, char ** argv) {
543545
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
544546
n_past, n_left, n_ctx, params.n_keep, n_discard);
545547

546-
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
547-
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
548+
llama_kv_cache_seq_rm (kv, 0, params.n_keep , params.n_keep + n_discard);
549+
llama_kv_cache_seq_add(kv, 0, params.n_keep + n_discard, n_past, -n_discard);
548550

549551
n_past -= n_discard;
550552

@@ -567,9 +569,9 @@ int main(int argc, char ** argv) {
567569
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
568570
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
569571

570-
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
571-
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
572-
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
572+
llama_kv_cache_seq_add(kv, 0, ga_i, n_past, ib*bd);
573+
llama_kv_cache_seq_div(kv, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
574+
llama_kv_cache_seq_add(kv, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
573575

574576
n_past -= bd;
575577

examples/parallel/parallel.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ int main(int argc, char ** argv) {
134134

135135
llama_model * model = llama_init.model.get();
136136
llama_context * ctx = llama_init.context.get();
137+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
137138

138139
const llama_vocab * vocab = llama_model_get_vocab(model);
139140

@@ -201,7 +202,7 @@ int main(int argc, char ** argv) {
201202

202203
// assign the system KV cache to all parallel sequences
203204
for (int32_t i = 1; i <= n_clients; ++i) {
204-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
205+
llama_kv_cache_seq_cp(kv, 0, i, -1, -1);
205206
}
206207

207208
LOG_INF("\n");
@@ -233,9 +234,9 @@ int main(int argc, char ** argv) {
233234
if (batch.n_tokens == 0) {
234235
// all sequences have ended - clear the entire KV cache
235236
for (int i = 1; i <= n_clients; ++i) {
236-
llama_kv_cache_seq_rm(ctx, i, -1, -1);
237+
llama_kv_cache_seq_rm(kv, i, -1, -1);
237238
// but keep the system prompt
238-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
239+
llama_kv_cache_seq_cp(kv, 0, i, -1, -1);
239240
}
240241

241242
LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -371,8 +372,8 @@ int main(int argc, char ** argv) {
371372
}
372373

373374
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
374-
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
375-
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
375+
llama_kv_cache_seq_rm(kv, client.id + 1, -1, -1);
376+
llama_kv_cache_seq_cp(kv, 0, client.id + 1, -1, -1);
376377

377378
const auto t_main_end = ggml_time_us();
378379

examples/passkey/passkey.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ int main(int argc, char ** argv) {
8686
return 1;
8787
}
8888

89+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
90+
8991
auto sparams = llama_sampler_chain_default_params();
9092

9193
llama_sampler * smpl = llama_sampler_chain_init(sparams);
@@ -132,11 +134,11 @@ int main(int argc, char ** argv) {
132134
const int ib = i/n_batch - 1;
133135
const int bd = n_batch_grp*(n_grp - 1);
134136

135-
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
136-
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
137-
llama_kv_cache_update (ctx);
137+
llama_kv_cache_seq_add(kv, 0, n_past - n_batch, n_past, ib*bd);
138+
llama_kv_cache_seq_div(kv, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
139+
llama_update_kv_cache (ctx, kv);
138140

139-
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
141+
n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1;
140142
}
141143

142144
common_batch_clear(batch);
@@ -166,12 +168,12 @@ int main(int argc, char ** argv) {
166168

167169
LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
168170

169-
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
170-
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
171-
//llama_kv_cache_defrag (ctx);
172-
llama_kv_cache_update (ctx);
171+
llama_kv_cache_seq_rm (kv, 0, n_keep , n_keep + n_discard);
172+
llama_kv_cache_seq_add(kv, 0, n_keep + n_discard, n_ctx, -n_discard);
173+
//llama_kv_cache_defrag (kv);
174+
llama_update_kv_cache (ctx, kv);
173175

174-
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
176+
n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1;
175177

176178
common_batch_clear(batch);
177179

@@ -197,12 +199,12 @@ int main(int argc, char ** argv) {
197199
if (n_discard > 0) {
198200
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
199201

200-
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
201-
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
202-
//llama_kv_cache_defrag (ctx);
203-
llama_kv_cache_update (ctx);
202+
llama_kv_cache_seq_rm (kv, 0, n_keep , n_keep + n_discard);
203+
llama_kv_cache_seq_add(kv, 0, n_keep + n_discard, n_ctx, -n_discard);
204+
//llama_kv_cache_defrag (kv);
205+
llama_update_kv_cache (ctx, kv);
204206

205-
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
207+
n_past = llama_kv_cache_seq_pos_max(kv, 0) + 1;
206208
}
207209
}
208210

0 commit comments

Comments
 (0)