
Commit dad5c44

kv-cache : avoid modifying recurrent cells when setting inputs (#13834)
* kv-cache : avoid modifying recurrent cells when setting inputs

* kv-cache : remove inp_s_mask

  It was replaced with equivalent and simpler functionality with rs_z
  (the first zeroed state) and the already-existing inp_s_copy.

* kv-cache : fix non-consecutive token pos warning for recurrent models

  The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache

* llama-graph : fix recurrent state copy

  The `state_copy` shuffle assumes everything is moved at once, which is not
  true when `states_extra` is copied back to the cache before copying the
  range of states between `head` and `head + n_seqs`. This is only a problem
  if any of the cells in [`head`, `head + n_seqs`) have an `src` in
  [`head + n_seqs`, `head + n_kv`), which does happen when `n_ubatch > 1`
  in the `llama-parallel` example.

  Changing the order of the operations avoids the potential overwrite before
  use, although when copies are avoided (like with Mamba2), this will require
  further changes.

* llama-graph : rename n_state to state_size in build_recurrent_state

  This naming should reduce confusion between the state size and the number
  of states.
1 parent 55f6b9f commit dad5c44
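
The core change can be pictured with a small standalone sketch (not part of the commit): instead of building a float mask and multiplying every state row by it, one cell that nothing references (rs_z) is zeroed once, and sequences that start fresh gather their state from that row through the existing copy indices. The names below (state, copy_src, rs_z) are illustrative stand-ins for the cache tensors and the inp_s_copy indices used by the real code.

// Standalone sketch (not part of the commit): the idea behind rs_z. Instead of
// building a float mask and multiplying every state row by it (the old
// inp_s_mask), one cell that nothing references is zeroed in place and every
// sequence that starts fresh gathers its state from that row through the
// existing copy indices. The names (state, copy_src, rs_z) are illustrative.
#include <cstdio>
#include <vector>

int main() {
    const int n_kv = 3; // cells visible to this ubatch

    // one small recurrent state row per cell
    std::vector<std::vector<float>> state = {
        {1.0f, 1.0f, 1.0f}, {2.0f, 2.0f, 2.0f}, {3.0f, 3.0f, 3.0f},
    };

    // cells 0 and 1 continue their sequences; nothing reads cell 2's old state,
    // so it is picked as rs_z and zeroed once
    const int rs_z = 2;
    for (float & x : state[rs_z]) {
        x = 0.0f; // stands in for the single ggml_scale_inplace(..., 0)
    }

    // copy indices (inp_s_copy): continuing cells keep their own state,
    // the freshly started sequence in cell 2 sources the zeroed row
    const std::vector<int> copy_src = {0, 1, rs_z};

    // the gather that ggml_get_rows performs on the graph side
    for (int i = 0; i < n_kv; ++i) {
        const std::vector<float> & row = state[copy_src[i]];
        printf("cell %d: %g %g %g\n", i, row[0], row[1], row[2]);
    }
    // cell 2 comes out as all zeros without any mask tensor being built
    return 0;
}

This is why inp_s_mask and its set_input path can be deleted in the diff below.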

File tree

6 files changed: +117 -180 lines changed


src/llama-graph.cpp

Lines changed: 30 additions & 53 deletions
@@ -250,22 +250,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    const int64_t n_kv = kv_state->get_n_kv();
-
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
-
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_mask(i);
-        }
-    }
-}
-
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -987,23 +971,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1456,43 +1423,53 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
        ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
-        int32_t n_state,
-        int32_t n_seqs) const {
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
 
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx0, states, state_copy);
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx0, states, state_mask);
+    ggml_tensor * output_states;
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
 
-    // copy states which won't be changed further (between n_seqs and n_kv)
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
 
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return output_states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1503,8 +1480,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_copy_mask_state(
-            gf, token_shift_all, state_copy, state_mask,
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
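
The reordering described in the commit message ("llama-graph : fix recurrent state copy") comes down to reading before writing. The standalone sketch below (not from the repository; all names are illustrative) mimics the hazard with plain arrays: the output rows [0, n_seqs) must be gathered before the extra rows [n_seqs, n_kv) are written back, because an output row may source its state from one of those extra slots, which is exactly the case the commit describes for n_ubatch > 1 in llama-parallel.

// Standalone sketch (not from the repository) of the ordering enforced in
// build_recurrent_state: gather the output rows before writing the "extra"
// rows back into the cache. All names (cache, state_copy, n_seqs, n_kv) are
// illustrative stand-ins for the graph tensors.
#include <cassert>
#include <vector>

int main() {
    const int n_seqs = 1; // rows used and modified by this ubatch
    const int n_kv   = 2; // rows visible to this ubatch

    std::vector<float> cache      = {10.0f, 20.0f}; // one scalar "state" per cell
    std::vector<int>   state_copy = {1, 0};         // cell 0 sources cell 1, cell 1 sources cell 0

    // correct order (what the commit enforces with ggml_build_forward_expand):
    // gather the output states first ...
    std::vector<float> output(n_seqs);
    for (int i = 0; i < n_seqs; ++i) {
        output[i] = cache[state_copy[i]];
    }

    // ... and only then copy the extra states back into the cache
    for (int i = n_seqs; i < n_kv; ++i) {
        cache[i] = cache[state_copy[i]];
    }

    assert(output[0] == 20.0f); // cell 0 received cell 1's old state
    assert(cache[1]  == 10.0f); // cell 1 received cell 0's old state

    // had the write-back run first, cache[1] would already hold 10.0f and the
    // gather for cell 0 would read the overwritten value instead of 20.0f
    return 0;
}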

src/llama-graph.h

Lines changed: 4 additions & 18 deletions
@@ -200,18 +200,6 @@ class llm_graph_input_s_copy : public llm_graph_input_i {
     const llama_kv_cache_recurrent_state * kv_state;
 };
 
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_mask() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_mask; // F32 [1, n_kv]
-
-    const llama_kv_cache_recurrent_state * kv_state;
-};
-
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -521,7 +509,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
-    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -606,18 +593,17 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_copy_mask_state(
+    ggml_tensor * build_recurrent_state(
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            int32_t n_state,
-            int32_t n_seqs) const;
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             ggml_cgraph * gf,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int il) const;
 

src/llama-kv-cache-recurrent.cpp

Lines changed: 60 additions & 75 deletions
@@ -406,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 
     bool success = true;
 
-    // TODO: here we have to verify that all ubatches can fit in the cells
-    //       however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
-    //       during the compute of each ubatch. to reproduce, uncomment the following loop and run:
-    //
-    //       $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
-    //
-    //       recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
-    //
-    GGML_UNUSED(ubatches);
-    //for (const auto & ubatch : ubatches) {
-    //    if (!find_slot(ubatch)) {
-    //        success = false;
-    //        break;
-    //    }
-    //}
+    for (const auto & ubatch : ubatches) {
+        if (!find_slot(ubatch)) {
+            success = false;
+            break;
+        }
+    }
 
     // restore the original state
     cells = std::move(org_cells);
@@ -431,14 +422,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 }
 
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs = ubatch.n_seqs;
+    const uint32_t n_seqs = ubatch.n_seqs;
 
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*n_tokens) {
+    if (head > used + 2*n_seqs) {
         head = 0;
     }
 
@@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                 empty_cell.src = orig_cell.src;
                 orig_cell.seq_id.erase(seq_id);
                 empty_cell.seq_id.insert(seq_id); // will be overwritten
+                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
             }
             seq_meta.tail = next_empty_cell;
             // find next empty cell
             if (s + 1 < n_seqs) {
-                next_empty_cell += 1;
                 for (uint32_t i = 0; i < size; ++i) {
+                    next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
                     kv_cell & cell = cells[next_empty_cell];
                     if (cell.is_empty()) { break; }
-                    next_empty_cell += 1;
                 }
             }
         }
@@ -553,8 +543,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        int32_t dst_id = s + min;
-        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        const int32_t dst_id = s + min;
+        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
             kv_cell & dst_cell = cells[dst_id];
             kv_cell & src_cell = cells[src_id];
@@ -563,20 +553,22 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.src, src_cell.src);
             std::swap(dst_cell.seq_id, src_cell.seq_id);
 
-            // swap tails (assuming they NEVER overlap)
-            for (const llama_seq_id seq_id : src_cell.seq_id) {
-                cells[seq_id].tail = src_id;
-            }
-            for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                cells[seq_id].tail = dst_id;
+            // swap tails
+            for (uint32_t i = 0; i < size; ++i) {
+                int32_t & tail = cells[i].tail;
+                if (tail == src_id) {
+                    tail = dst_id;
+                } else if (tail == dst_id) {
+                    tail = src_id;
+                }
             }
         }
     }
 
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-        int32_t cell_id = s + min;
+        const int32_t cell_id = s + min;
         kv_cell & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -594,6 +586,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         }
     }
 
+    // Find first cell without src refs, to use as the zero-ed state
+    {
+        // TODO: bake-in src refcounts in the cell metadata
+        std::vector<int32_t> refcounts(size, 0);
+        for (size_t i = 0; i < size; ++i) {
+            const int32_t src = cells[i].src;
+            if (src >= 0) {
+                refcounts[src] += 1;
+            }
+        }
+
+        rs_z = -1;
+        for (int i = min; i <= max; ++i) {
+            if (refcounts[i] == 0) {
+                rs_z = i;
+                break;
+            }
+        }
+
+        for (int i = min; i <= max; ++i) {
+            if (cells[i].src < 0) {
+                GGML_ASSERT(rs_z >= 0);
+                cells[i].src0 = rs_z;
+            } else {
+                // Stage the source ids for all used cells to allow correct seq_* behavior
+                // and still make these values available when setting the inputs
+                cells[i].src0 = cells[i].src;
+            }
+            cells[i].src = i; // avoid moving or clearing twice
+        }
+    }
+
     // allow getting the range of used cells, from head to head + n
     head = min;
     n = max - min + 1;
@@ -605,47 +629,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 }
 
 bool llama_kv_cache_recurrent::get_can_shift() const {
-    return false;
-}
-
-int32_t llama_kv_cache_recurrent::s_copy(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    // prevent out-of-bound sources
-    if (cell.src < 0 || (uint32_t) cell.src >= size) {
-        cell.src = cell_id;
-    }
-
-    int32_t res = cell.src;
-
-    // TODO: do not mutate the KV cache
-    // ensure copy only happens once
-    if (cell.src != (int32_t) cell_id) {
-        cell.src = cell_id;
-    }
-
-    return res;
-}
-
-float llama_kv_cache_recurrent::s_mask(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    float res = (float) (cell.src >= 0);
-
-    // only clear once
-    if (cell.src < 0) {
-        cell.src = cell_id;
-    }
-
-    return res;
+    // shifting the pos is trivial for recurrent models
+    return true;
 }
 
 size_t llama_kv_cache_recurrent::total_size() const {
@@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
     return is_full ? 0 : kv->head;
 }
 
+int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : kv->rs_z;
+}
+
 uint32_t llama_kv_cache_recurrent_state::get_size() const {
     return kv->size;
 }
@@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
 }
 
 int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
-    return kv->s_copy(i);
-}
-
-float llama_kv_cache_recurrent_state::s_mask(int i) const {
-    return kv->s_mask(i);
+    return kv->cells[i + kv->head].src0;
 }
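
As a rough model of the bookkeeping added to find_slot above, the standalone sketch below (not from the repository; the cell struct and its field names are illustrative) picks rs_z by refcounting copy sources, stages each cell's effective source in a separate src0 field, and then resets src so that reading the copy indices later, as the new s_copy accessor does, no longer has to mutate the cache.

// Standalone sketch (not from the repository) of the rs_z bookkeeping added to
// find_slot: refcount copy sources, pick the first unreferenced cell in the used
// range as the zero-ed state, and stage per-cell sources in src0 so that reading
// them later (the new s_copy accessor) does not mutate the cache.
// The cell struct and all field names here are illustrative.
#include <cstdio>
#include <vector>

struct cell {
    int src  = -1; // where this cell's state should be copied from (-1 = start fresh)
    int src0 = -1; // staged source id, resolved against rs_z
};

int main() {
    // four cells in the used range [min, max]; cell 2 starts a new sequence
    std::vector<cell> cells(4);
    cells[0].src =  1; // continues from cell 1 (e.g. after a sequence copy)
    cells[1].src =  1; // continues its own state
    cells[2].src = -1; // fresh sequence
    cells[3].src =  3; // continues its own state

    const int min = 0, max = 3;

    // count how many cells read each cell's old state
    std::vector<int> refcounts(cells.size(), 0);
    for (const cell & c : cells) {
        if (c.src >= 0) {
            refcounts[c.src] += 1;
        }
    }

    // first cell whose old state nothing reads -> safe to zero on the graph side
    int rs_z = -1;
    for (int i = min; i <= max; ++i) {
        if (refcounts[i] == 0) {
            rs_z = i;
            break;
        }
    }

    for (int i = min; i <= max; ++i) {
        cells[i].src0 = cells[i].src < 0 ? rs_z : cells[i].src;
        cells[i].src  = i; // the graph performs the copy; avoid doing it twice
    }

    for (int i = min; i <= max; ++i) {
        printf("cell %d: src0 = %d\n", i, cells[i].src0);
    }
    // prints: cells 0 and 1 copy from cell 1, cell 2 copies from the zeroed cell 0,
    // cell 3 copies from itself
    return 0;
}

Note that rs_z does not have to be the freshly started cell itself; any cell in the used range whose old state nothing reads is a valid choice.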
