Skip to content

Commit 85dbd6f

Browse files
committed
kv-cache : simplify interface (wip)
ggml-ci
1 parent b2ef3ae commit 85dbd6f

File tree

4 files changed

+108
-93
lines changed

4 files changed

+108
-93
lines changed

src/llama-context.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,7 @@ int llama_context::decode(llama_batch & inp_batch) {
11081108

11091109
// decide if we need to defrag the kv cache
11101110
if (cparams.defrag_thold > 0.0f) {
1111-
kv_self->defrag(cparams.defrag_thold);
1111+
kv_self->defrag_sched(cparams.defrag_thold);
11121112
}
11131113

11141114
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -2152,7 +2152,7 @@ void llama_kv_cache_seq_cp(
21522152
llama_seq_id seq_id_dst,
21532153
llama_pos p0,
21542154
llama_pos p1) {
2155-
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
2155+
llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
21562156
}
21572157

21582158
void llama_kv_self_seq_cp(
@@ -2166,14 +2166,14 @@ void llama_kv_self_seq_cp(
21662166
return;
21672167
}
21682168

2169-
return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2169+
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
21702170
}
21712171

21722172
// deprecated
21732173
void llama_kv_cache_seq_keep(
21742174
llama_context * ctx,
21752175
llama_seq_id seq_id) {
2176-
return llama_kv_self_seq_keep(ctx, seq_id);
2176+
llama_kv_self_seq_keep(ctx, seq_id);
21772177
}
21782178

21792179
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@@ -2182,7 +2182,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
21822182
return;
21832183
}
21842184

2185-
return kv->seq_keep(seq_id);
2185+
kv->seq_keep(seq_id);
21862186
}
21872187

21882188
// deprecated
@@ -2192,7 +2192,7 @@ void llama_kv_cache_seq_add(
21922192
llama_pos p0,
21932193
llama_pos p1,
21942194
llama_pos delta) {
2195-
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
2195+
llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
21962196
}
21972197

21982198
void llama_kv_self_seq_add(
@@ -2206,7 +2206,7 @@ void llama_kv_self_seq_add(
22062206
return;
22072207
}
22082208

2209-
return kv->seq_add(seq_id, p0, p1, delta);
2209+
kv->seq_add(seq_id, p0, p1, delta);
22102210
}
22112211

22122212
// deprecated
@@ -2216,7 +2216,7 @@ void llama_kv_cache_seq_div(
22162216
llama_pos p0,
22172217
llama_pos p1,
22182218
int d) {
2219-
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
2219+
llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
22202220
}
22212221

22222222
void llama_kv_self_seq_div(
@@ -2230,7 +2230,7 @@ void llama_kv_self_seq_div(
22302230
return;
22312231
}
22322232

2233-
return kv->seq_div(seq_id, p0, p1, d);
2233+
kv->seq_div(seq_id, p0, p1, d);
22342234
}
22352235

22362236
// deprecated
@@ -2249,7 +2249,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
22492249

22502250
// deprecated
22512251
void llama_kv_cache_defrag(llama_context * ctx) {
2252-
return llama_kv_self_defrag(ctx);
2252+
llama_kv_self_defrag(ctx);
22532253
}
22542254

22552255
void llama_kv_self_defrag(llama_context * ctx) {
@@ -2259,7 +2259,7 @@ void llama_kv_self_defrag(llama_context * ctx) {
22592259
}
22602260

22612261
// force defrag
2262-
return kv->defrag(-1.0f);
2262+
kv->defrag_sched(-1.0f);
22632263
}
22642264

22652265
// deprecated

src/llama-graph.cpp

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -270,24 +270,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
270270

271271
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
272272
for (uint32_t i = 0; i < n_kv; ++i) {
273-
const uint32_t cell_id = i + kv_self->head;
274-
275-
//////////////////////////////////////////////
276-
// TODO: this should not mutate the KV cache !
277-
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];
278-
279-
// prevent out-of-bound sources
280-
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
281-
kv_cell.src = cell_id;
282-
}
283-
284-
data[i] = kv_cell.src;
285-
286-
// TODO: do not mutate the KV cache
287-
// ensure copy only happens once
288-
if (kv_cell.src != (int32_t) cell_id) {
289-
kv_cell.src = cell_id;
290-
}
273+
data[i] = kv_self->s_copy(i);
291274
}
292275
}
293276
}
@@ -303,18 +286,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
303286

304287
// clear unused states
305288
for (int i = 0; i < n_kv; ++i) {
306-
const uint32_t cell_id = i + kv_self->head;
307-
308-
//////////////////////////////////////////////
309-
// TODO: this should not mutate the KV cache !
310-
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];
311-
312-
data[i] = (float) (kv_cell.src >= 0);
313-
314-
// only clear once
315-
if (kv_cell.src < 0) {
316-
kv_cell.src = cell_id;
317-
}
289+
data[i] = kv_self->s_mask(i);
318290
}
319291
}
320292
}

src/llama-kv-cache.cpp

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,6 @@ int32_t llama_kv_cache_unified::get_used_cells() const {
130130
return used;
131131
}
132132

133-
bool llama_kv_cache_unified::get_has_shift() const {
134-
return has_shift;
135-
}
136-
137-
bool llama_kv_cache_unified::get_do_defrag() const {
138-
return do_defrag;
139-
}
140-
141133
size_t llama_kv_cache_unified::total_size() const {
142134
size_t size = 0;
143135
for (const auto & buf : bufs) {
@@ -358,10 +350,10 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
358350
return result;
359351
}
360352

361-
void llama_kv_cache_unified::defrag(float thold) {
353+
void llama_kv_cache_unified::defrag_sched(float thold) {
362354
// - do not defrag small contexts (i.e. < 2048 tokens)
363355
// - count the padding towards the number of used tokens
364-
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - float(used + padding)/float(n)) : 0.0f;
356+
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
365357

366358
// queue defragmentation for next llama_kv_cache_update
367359
if (fragmentation > thold) {
@@ -699,7 +691,7 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
699691

700692
const auto & sched = params.sched;
701693

702-
if (get_has_shift()) {
694+
if (has_shift) {
703695
if (!get_can_shift()) {
704696
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
705697
}
@@ -732,7 +724,7 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
732724
}
733725
}
734726

735-
if (get_do_defrag()) {
727+
if (do_defrag) {
736728
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
737729

738730
if (defrag_prepare(params.n_max_nodes)) {
@@ -1496,14 +1488,6 @@ int32_t llama_kv_cache_recurrent::get_used_cells() const {
14961488
return used;
14971489
}
14981490

1499-
bool llama_kv_cache_recurrent::get_has_shift() const {
1500-
return false;
1501-
}
1502-
1503-
bool llama_kv_cache_recurrent::get_do_defrag() const {
1504-
return false;
1505-
}
1506-
15071491
size_t llama_kv_cache_recurrent::total_size() const {
15081492
size_t size = 0;
15091493
for (const auto & buf : bufs) {
@@ -1716,7 +1700,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
17161700
return result;
17171701
}
17181702

1719-
void llama_kv_cache_recurrent::defrag(float thold) {
1703+
void llama_kv_cache_recurrent::defrag_sched(float thold) {
17201704
GGML_UNUSED(thold);
17211705
// noop
17221706
}
@@ -1742,6 +1726,46 @@ bool llama_kv_cache_recurrent::get_can_shift() const {
17421726
return false;
17431727
}
17441728

1729+
int32_t llama_kv_cache_recurrent::s_copy(int i) const {
1730+
const uint32_t cell_id = i + head;
1731+
1732+
//////////////////////////////////////////////
1733+
// TODO: this should not mutate the KV cache !
1734+
llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
1735+
1736+
// prevent out-of-bound sources
1737+
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= size) {
1738+
kv_cell.src = cell_id;
1739+
}
1740+
1741+
int32_t res = kv_cell.src;
1742+
1743+
// TODO: do not mutate the KV cache
1744+
// ensure copy only happens once
1745+
if (kv_cell.src != (int32_t) cell_id) {
1746+
kv_cell.src = cell_id;
1747+
}
1748+
1749+
return res;
1750+
}
1751+
1752+
float llama_kv_cache_recurrent::s_mask(int i) const {
1753+
const uint32_t cell_id = i + head;
1754+
1755+
//////////////////////////////////////////////
1756+
// TODO: this should not mutate the KV cache !
1757+
llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
1758+
1759+
float res = (float) (kv_cell.src >= 0);
1760+
1761+
// only clear once
1762+
if (kv_cell.src < 0) {
1763+
kv_cell.src = cell_id;
1764+
}
1765+
1766+
return res;
1767+
}
1768+
17451769
bool llama_kv_cache_recurrent::find_slot(
17461770
const llama_ubatch & ubatch) {
17471771
const uint32_t n_tokens = ubatch.n_tokens;

0 commit comments

Comments
 (0)