Skip to content

Commit e7f94f8

Browse files
committed
llama : update llama_kv_self API
ggml-ci
1 parent fb74024 commit e7f94f8

File tree

30 files changed

+386
-203
lines changed

30 files changed

+386
-203
lines changed

common/common.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -893,9 +893,7 @@ struct common_init_result common_init_from_params(common_params & params) {
893893
return iparams;
894894
}
895895

896-
llama_kv_cache * kv = llama_get_kv_cache(lctx);
897-
898-
if (params.ctx_shift && !llama_kv_cache_can_shift(kv)) {
896+
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
899897
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
900898
params.ctx_shift = false;
901899
}
@@ -1000,7 +998,7 @@ struct common_init_result common_init_from_params(common_params & params) {
1000998
if (llama_model_has_decoder(model)) {
1001999
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
10021000
}
1003-
llama_kv_cache_clear(kv);
1001+
llama_kv_self_clear(lctx);
10041002
llama_synchronize(lctx);
10051003
llama_perf_context_reset(lctx);
10061004
}

common/speculative.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,8 @@ llama_tokens common_speculative_gen_draft(
171171
llama_tokens result;
172172
result.reserve(params.n_draft);
173173

174-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
175-
176174
if (reuse_n == 0) {
177-
llama_kv_cache_clear(kv);
175+
llama_kv_self_clear(ctx);
178176

179177
prompt.clear();
180178
} else {
@@ -193,14 +191,14 @@ llama_tokens common_speculative_gen_draft(
193191
}
194192

195193
if (reuse_i > 0) {
196-
llama_kv_cache_seq_rm (kv, 0, 0, reuse_i);
197-
llama_kv_cache_seq_add(kv, 0, reuse_i, -1, -reuse_i);
194+
llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
195+
llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
198196

199197
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
200198
}
201199

202200
if (reuse_n < (int) prompt.size()) {
203-
llama_kv_cache_seq_rm (kv, 0, reuse_n, -1);
201+
llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
204202

205203
prompt.erase(prompt.begin() + reuse_n, prompt.end());
206204
}

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@ int main(int argc, char ** argv) {
5757
return 1;
5858
}
5959

60-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
61-
6260
const int32_t n_kv_max = llama_n_ctx(ctx);
6361

6462
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
@@ -134,7 +132,7 @@ int main(int argc, char ** argv) {
134132

135133
const auto t_pp_start = ggml_time_us();
136134

137-
llama_kv_cache_clear(kv);
135+
llama_kv_self_clear(ctx);
138136

139137
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
140138
LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -143,7 +141,7 @@ int main(int argc, char ** argv) {
143141

144142
if (is_pp_shared) {
145143
for (int32_t i = 1; i < pl; ++i) {
146-
llama_kv_cache_seq_cp(kv, 0, i, -1, -1);
144+
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
147145
}
148146
}
149147

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ if llama_decode(context, batch) != 0 {
111111
}
112112

113113
for i in 1 ..< n_parallel {
114-
llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
114+
llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
115115
}
116116

117117
if n_parallel > 1 {

examples/cvector-generator/cvector-generator.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
342342
}
343343

344344
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
345-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
346-
llama_kv_cache_clear(kv);
345+
llama_kv_self_clear(ctx);
347346
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
348347
fprintf(stderr, "%s : failed to eval\n", __func__);
349348
return false;

examples/embedding/embedding.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,10 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
3434

3535
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
3636
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
37-
const llama_model * model = llama_get_model(ctx);
38-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
37+
const struct llama_model * model = llama_get_model(ctx);
3938

4039
// clear previous kv_cache values (irrelevant for embeddings)
41-
llama_kv_cache_clear(kv);
40+
llama_kv_self_clear(ctx);
4241

4342
// run model
4443
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

examples/gritlm/gritlm.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
1313
const llama_model * model = llama_get_model(ctx);
1414
const llama_vocab * vocab = llama_model_get_vocab(model);
1515

16-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
17-
1816
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
1917

2018
for (uint64_t i = 0; i < sentences.size(); i++) {
@@ -47,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
4745
}
4846

4947
// clear previous kv_cache values (irrelevant for embeddings)
50-
llama_kv_cache_clear(kv);
48+
llama_kv_self_clear(ctx);
5149
llama_set_embeddings(ctx, true);
5250
llama_set_causal_attn(ctx, false);
5351

@@ -102,11 +100,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
102100
const llama_model * model = llama_get_model(ctx);
103101
const llama_vocab * vocab = llama_model_get_vocab(model);
104102

105-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
106-
107103
llama_token eos_token = llama_vocab_eos(vocab);
108104

109-
llama_kv_cache_clear(kv);
105+
llama_kv_self_clear(ctx);
110106
llama_set_embeddings(ctx, false);
111107
llama_set_causal_attn(ctx, true);
112108

examples/imatrix/imatrix.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,6 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
431431
const llama_model * model = llama_get_model(ctx);
432432
const llama_vocab * vocab = llama_model_get_vocab(model);
433433

434-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
435-
436434
const bool add_bos = llama_vocab_get_add_bos(vocab);
437435
const int n_ctx = llama_n_ctx(ctx);
438436

@@ -499,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
499497
const auto t_start = std::chrono::high_resolution_clock::now();
500498

501499
// clear the KV cache
502-
llama_kv_cache_clear(kv);
500+
llama_kv_self_clear(ctx);
503501

504502
llama_batch batch = llama_batch_init(n_batch, 0, 1);
505503

examples/infill/infill.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,6 @@ int main(int argc, char ** argv) {
139139
return 1;
140140
}
141141

142-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
143-
144142
const llama_vocab * vocab = llama_model_get_vocab(model);
145143

146144
const int n_ctx_train = llama_model_n_ctx_train(model);
@@ -334,8 +332,8 @@ int main(int argc, char ** argv) {
334332
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
335333
n_past, n_left, n_ctx, params.n_keep, n_discard);
336334

337-
llama_kv_cache_seq_rm (kv, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
338-
llama_kv_cache_seq_add(kv, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
335+
llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
336+
llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
339337

340338
n_past -= n_discard;
341339

examples/llama-bench/llama-bench.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,11 +1546,9 @@ int main(int argc, char ** argv) {
15461546
return 1;
15471547
}
15481548

1549-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
1550-
15511549
test t(inst, lmodel, ctx);
15521550

1553-
llama_kv_cache_clear(kv);
1551+
llama_kv_self_clear(ctx);
15541552

15551553
// cool off before the test
15561554
if (params.delay) {
@@ -1590,7 +1588,7 @@ int main(int argc, char ** argv) {
15901588
}
15911589

15921590
for (int i = 0; i < params.reps; i++) {
1593-
llama_kv_cache_clear(kv);
1591+
llama_kv_self_clear(ctx);
15941592

15951593
uint64_t t_start = get_time_ns();
15961594

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
194194
}
195195

196196
batch->logits[batch->n_tokens - 1] = true;
197-
llama_kv_cache_clear(context);
197+
llama_kv_self_clear(context);
198198

199199
const auto t_pp_start = ggml_time_us();
200200
if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
206206

207207
LOGi("Benchmark text generation (tg)");
208208

209-
llama_kv_cache_clear(context);
209+
llama_kv_self_clear(context);
210210
const auto t_tg_start = ggml_time_us();
211211
for (i = 0; i < tg; i++) {
212212

@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
223223

224224
const auto t_tg_end = ggml_time_us();
225225

226-
llama_kv_cache_clear(context);
226+
llama_kv_self_clear(context);
227227

228228
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
229229
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -446,5 +446,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
446446
extern "C"
447447
JNIEXPORT void JNICALL
448448
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
449-
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
449+
llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
450450
}

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ actor LlamaContext {
208208
}
209209
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
210210

211-
llama_kv_cache_clear(context)
211+
llama_kv_self_clear(context)
212212

213213
let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
214214

@@ -221,7 +221,7 @@ actor LlamaContext {
221221

222222
// bench text generation
223223

224-
llama_kv_cache_clear(context)
224+
llama_kv_self_clear(context)
225225

226226
let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
227227

@@ -240,7 +240,7 @@ actor LlamaContext {
240240

241241
let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
242242

243-
llama_kv_cache_clear(context)
243+
llama_kv_self_clear(context)
244244

245245
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
246246
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -290,7 +290,7 @@ actor LlamaContext {
290290
func clear() {
291291
tokens_list.removeAll()
292292
temporary_invalid_cchars.removeAll()
293-
llama_kv_cache_clear(context)
293+
llama_kv_self_clear(context)
294294
}
295295

296296
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

examples/lookahead/lookahead.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ int main(int argc, char ** argv) {
6060

6161
llama_model * model = llama_init.model.get();
6262
llama_context * ctx = llama_init.context.get();
63-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
6463

6564
const llama_vocab * vocab = llama_model_get_vocab(model);
6665

@@ -96,7 +95,7 @@ int main(int argc, char ** argv) {
9695
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
9796

9897
for (int s = 1; s < W + G + 1; ++s) {
99-
llama_kv_cache_seq_cp(kv, 0, s, -1, -1);
98+
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
10099
}
101100

102101
const auto t_enc_end = ggml_time_us();
@@ -438,17 +437,17 @@ int main(int argc, char ** argv) {
438437

439438
// KV cache management
440439
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
441-
llama_kv_cache_seq_rm(kv, -1, n_past, -1);
440+
llama_kv_self_seq_rm(ctx, -1, n_past, -1);
442441

443442
if (seq_id_best != 0) {
444443
// if a verification token matched, we keep the best sequence and remove the rest
445444
// this leads to some KV cache fragmentation
446-
llama_kv_cache_seq_keep(kv, seq_id_best);
447-
llama_kv_cache_seq_cp (kv, seq_id_best, 0, -1, -1);
448-
llama_kv_cache_seq_rm (kv, seq_id_best, -1, -1);
445+
llama_kv_self_seq_keep(ctx, seq_id_best);
446+
llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1);
447+
llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1);
449448

450449
for (int s = 1; s < W + G + 1; ++s) {
451-
llama_kv_cache_seq_cp(kv, 0, s, -1, -1);
450+
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
452451
}
453452
}
454453
}

examples/lookup/lookup.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ int main(int argc, char ** argv){
3535

3636
llama_model * model = llama_init.model.get();
3737
llama_context * ctx = llama_init.context.get();
38-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
3938

4039
const llama_vocab * vocab = llama_model_get_vocab(model);
4140

@@ -193,7 +192,7 @@ int main(int argc, char ** argv){
193192

194193
// KV cache management
195194
// clean the cache of draft tokens that weren't accepted
196-
llama_kv_cache_seq_rm(kv, 0, n_past, -1);
195+
llama_kv_self_seq_rm(ctx, 0, n_past, -1);
197196

198197
common_batch_clear(batch_tgt);
199198
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

examples/main/main.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,6 @@ int main(int argc, char ** argv) {
162162
return 1;
163163
}
164164

165-
llama_kv_cache * kv = llama_get_kv_cache(ctx);
166-
167165
const llama_vocab * vocab = llama_model_get_vocab(model);
168166

169167
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@@ -308,7 +306,7 @@ int main(int argc, char ** argv) {
308306
}
309307

310308
// remove any "future" tokens that we might have inherited from the previous session
311-
llama_kv_cache_seq_rm(kv, -1, n_matching_session_tokens, -1);
309+
llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
312310
}
313311

314312
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -545,8 +543,8 @@ int main(int argc, char ** argv) {
545543
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
546544
n_past, n_left, n_ctx, params.n_keep, n_discard);
547545

548-
llama_kv_cache_seq_rm (kv, 0, params.n_keep , params.n_keep + n_discard);
549-
llama_kv_cache_seq_add(kv, 0, params.n_keep + n_discard, n_past, -n_discard);
546+
llama_kv_self_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
547+
llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
550548

551549
n_past -= n_discard;
552550

@@ -569,9 +567,9 @@ int main(int argc, char ** argv) {
569567
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
570568
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
571569

572-
llama_kv_cache_seq_add(kv, 0, ga_i, n_past, ib*bd);
573-
llama_kv_cache_seq_div(kv, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
574-
llama_kv_cache_seq_add(kv, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
570+
llama_kv_self_seq_add(ctx, 0, ga_i, n_past, ib*bd);
571+
llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
572+
llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
575573

576574
n_past -= bd;
577575

0 commit comments

Comments
 (0)