
Commit 3129639

kv-cache : avoid modifying recurrent cells when setting inputs
* kv-cache : remove inp_s_mask

  It was replaced with equivalent and simpler functionality with rs_z
  (the first zeroed state) and the already-existing inp_s_copy.

* kv-cache : fix non-consecutive token pos warning for recurrent models

  The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache
1 parent b3a89c3 commit 3129639
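In short: the explicit `s_mask` input, which was multiplied into the recurrent states to clear sequences starting in the current batch, is gone. Instead, `find_slot` picks `rs_z`, a cell that no other cell copies from; the graph zeroes that single cell, and the existing `inp_s_copy` indices make every cleared cell read from it, so the same gather that re-orders states also clears them. A minimal standalone sketch of the idea, using plain vectors and made-up sizes rather than the actual ggml tensors and graph calls:

```cpp
// Illustrative sketch only (not the ggml graph code): a single zeroed slot plus
// gather indices replaces the old state-mask multiply. Sizes and data are invented.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int n_state = 4; // elements per recurrent state
    const int n_kv    = 3; // cells visible to this ubatch

    // flattened [n_kv][n_state] state buffer
    std::vector<float> states = {
        1, 1, 1, 1,   // cell 0: continuing sequence
        2, 2, 2, 2,   // cell 1: new sequence, must start from zeros
        3, 3, 3, 3,   // cell 2: continuing sequence
    };

    // copy sources per cell; a cell starting a new sequence points at rs_z
    const int rs_z = 1;                      // cell chosen to hold the zeroed state
    std::vector<int> s_copy = {0, rs_z, 2};  // analogous to inp_s_copy

    // zero the single designated cell once (ggml_scale_inplace(..., 0) in the graph)
    for (int j = 0; j < n_state; ++j) {
        states[rs_z*n_state + j] = 0.0f;
    }

    // one gather (ggml_get_rows) now both copies and "clears" states:
    // cleared cells simply read the zeroed row instead of being masked out
    std::vector<float> gathered(n_kv*n_state);
    for (int i = 0; i < n_kv; ++i) {
        for (int j = 0; j < n_state; ++j) {
            gathered[i*n_state + j] = states[s_copy[i]*n_state + j];
        }
    }

    assert(gathered[1*n_state + 0] == 0.0f); // the fresh sequence starts from zeros
    printf("row 1 after gather: %.0f %.0f %.0f %.0f\n",
           gathered[4], gathered[5], gathered[6], gathered[7]);
    return 0;
}
```

Because the clearing happens inside the graph via the copy indices, setting the inputs no longer has to mutate the cache cells, which is what the commit title refers to.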

5 files changed: +108 −177 lines changed


src/llama-graph.cpp

Lines changed: 22 additions & 52 deletions
@@ -247,22 +247,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    const int64_t n_kv = kv_state->get_n_kv();
-
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
-
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_mask(i);
-        }
-    }
-}
-
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -970,23 +954,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1439,43 +1406,46 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
          ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
              int32_t   n_state,
-             int32_t   n_seqs) const {
+             int32_t   n_seqs,
+                bool   avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
 
     ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
 
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx0, states, state_copy);
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx0, states, state_mask);
-
-    // copy states which won't be changed further (between n_seqs and n_kv)
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
+            states_extra,
             ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // this shrinks the tensors's ne[1] to n_seqs
+        states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    }
 
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_cgraph * gf,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
     const llama_ubatch & ubatch,
                     int   il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1486,8 +1456,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_copy_mask_state(
-            gf, token_shift_all, state_copy, state_mask,
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

src/llama-graph.h

Lines changed: 3 additions & 17 deletions
@@ -199,18 +199,6 @@ class llm_graph_input_s_copy : public llm_graph_input_i {
     const llama_kv_cache_recurrent_state * kv_state;
 };
 
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_mask() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_mask; // F32 [1, n_kv]
-
-    const llama_kv_cache_recurrent_state * kv_state;
-};
-
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -520,7 +508,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
-    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -605,18 +592,17 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_copy_mask_state(
+    ggml_tensor * build_recurrent_state(
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
                 int32_t   n_state,
-                int32_t   n_seqs) const;
+                int32_t   n_seqs,
+                   bool   avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             ggml_cgraph * gf,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                     int   il) const;
 
src/llama-kv-cache.cpp

Lines changed: 60 additions & 77 deletions
@@ -464,8 +464,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         head_cur = 0;
     }
 
-    // otherwise, one cell per token.
-
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
         return -1;
@@ -2344,21 +2342,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 
     bool success = true;
 
-    // TODO: here we have to verify that all ubatches can fit in the cells
-    //       however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
-    //       during the compute of each ubatch. to reproduce, uncomment the following loop and run:
-    //
-    //       $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
-    //
-    //       recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
-    //
-    GGML_UNUSED(ubatches);
-    //for (const auto & ubatch : ubatches) {
-    //    if (!find_slot(ubatch)) {
-    //        success = false;
-    //        break;
-    //    }
-    //}
+    for (const auto & ubatch : ubatches) {
+        if (!find_slot(ubatch)) {
+            success = false;
+            break;
+        }
+    }
 
     // restore the original state
     cells = std::move(org_cells);
@@ -2380,14 +2369,13 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
 }
 
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seqs = ubatch.n_seqs;
 
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*n_tokens) {
+    if (head > used + 2*n_seqs) {
         head = 0;
     }
 
@@ -2483,16 +2471,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                 empty_cell.src = orig_cell.src;
                 orig_cell.seq_id.erase(seq_id);
                 empty_cell.seq_id.insert(seq_id); // will be overwritten
+                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
             }
             seq_meta.tail = next_empty_cell;
             // find next empty cell
             if (s + 1 < n_seqs) {
-                next_empty_cell += 1;
                 for (uint32_t i = 0; i < size; ++i) {
+                    next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
                     kv_cell & cell = cells[next_empty_cell];
                     if (cell.is_empty()) { break; }
-                    next_empty_cell += 1;
                 }
             }
         }
@@ -2502,8 +2490,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        int32_t dst_id = s + min;
-        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        const int32_t dst_id = s + min;
+        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
             kv_cell & dst_cell = cells[dst_id];
             kv_cell & src_cell = cells[src_id];
@@ -2512,20 +2500,22 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.src, src_cell.src);
             std::swap(dst_cell.seq_id, src_cell.seq_id);
 
-            // swap tails (assuming they NEVER overlap)
-            for (const llama_seq_id seq_id : src_cell.seq_id) {
-                cells[seq_id].tail = src_id;
-            }
-            for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                cells[seq_id].tail = dst_id;
+            // swap tails
+            for (uint32_t i = 0; i < size; ++i) {
+                int32_t & tail = cells[i].tail;
+                if (tail == src_id) {
+                    tail = dst_id;
+                } else if (tail == dst_id) {
+                    tail = src_id;
+                }
             }
         }
     }
 
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-        int32_t cell_id = s + min;
+        const int32_t cell_id = s + min;
         kv_cell & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -2543,6 +2533,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         }
     }
 
+    // Find first cell without src refs, to use as the zero-ed state
+    {
+        // TODO: bake-in src refcounts in the cell metadata
+        std::vector<int32_t> refcounts(size, 0);
+        for (size_t i = 0; i < size; ++i) {
+            const int32_t src = cells[i].src;
+            if (src >= 0) {
+                refcounts[src] += 1;
+            }
+        }
+
+        rs_z = -1;
+        for (int i = min; i <= max; ++i) {
+            if (refcounts[i] == 0) {
+                rs_z = i;
+                break;
+            }
+        }
+
+        for (int i = min; i <= max; ++i) {
+            if (cells[i].src < 0) {
+                GGML_ASSERT(rs_z >= 0);
+                cells[i].src0 = rs_z;
+            } else {
+                // Stage the source ids for all used cells to allow correct seq_* behavior
+                // and still make these values available when setting the inputs
+                cells[i].src0 = cells[i].src;
+            }
+            cells[i].src = i; // avoid moving or clearing twice
+        }
+    }
+
     // allow getting the range of used cells, from head to head + n
     head = min;
     n = max - min + 1;
@@ -2554,47 +2576,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 }
 
 bool llama_kv_cache_recurrent::get_can_shift() const {
-    return false;
-}
-
-int32_t llama_kv_cache_recurrent::s_copy(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    // prevent out-of-bound sources
-    if (cell.src < 0 || (uint32_t) cell.src >= size) {
-        cell.src = cell_id;
-    }
-
-    int32_t res = cell.src;
-
-    // TODO: do not mutate the KV cache
-    // ensure copy only happens once
-    if (cell.src != (int32_t) cell_id) {
-        cell.src = cell_id;
-    }
-
-    return res;
-}
-
-float llama_kv_cache_recurrent::s_mask(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    float res = (float) (cell.src >= 0);
-
-    // only clear once
-    if (cell.src < 0) {
-        cell.src = cell_id;
-    }
-
-    return res;
+    // shifting the pos is trivial for recurrent models
+    return true;
 }
 
 size_t llama_kv_cache_recurrent::total_size() const {
@@ -3060,6 +3043,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
     return is_full ? 0 : kv->head;
 }
 
+int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : kv->rs_z;
+}
+
 uint32_t llama_kv_cache_recurrent_state::get_size() const {
     return kv->size;
 }
@@ -3073,9 +3060,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
 }
 
 int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
-    return kv->s_copy(i);
-}
-
-float llama_kv_cache_recurrent_state::s_mask(int i) const {
-    return kv->s_mask(i);
+    return kv->cells[i + kv->head].src0;
 }
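For reference, the rs_z selection added to find_slot above only considers cells whose state no other cell will copy from, since only such a cell can safely be zeroed in place. A toy sketch of that refcount scan, with a simplified stand-in for kv_cell and invented data:

```cpp
// Toy illustration (not library code) of the rs_z selection idea: only a cell
// that no other cell reads from (src refcount == 0) can be clobbered and reused
// as the shared "cleared" state.
#include <cstdio>
#include <vector>

struct toy_cell { int src; }; // simplified stand-in for kv_cell

int main() {
    // src == -1 means the cell starts a new sequence (needs a cleared state)
    std::vector<toy_cell> cells = { {0}, {-1}, {0}, {2} };
    const int min = 0, max = 3;

    std::vector<int> refcounts(cells.size(), 0);
    for (const auto & c : cells) {
        if (c.src >= 0) refcounts[c.src] += 1;
    }

    int rs_z = -1;
    for (int i = min; i <= max; ++i) {
        if (refcounts[i] == 0) { rs_z = i; break; } // first unreferenced cell
    }

    // cells with src < 0 are then redirected to the zeroed cell instead of being masked
    printf("rs_z = %d\n", rs_z); // cell 0 is read by cells 0 and 2, cell 2 by cell 3 -> rs_z = 1
    return 0;
}
```

Cells that start a new sequence then get src0 = rs_z, while used cells keep their staged source in src0, so s_copy() can later report the indices without modifying the cache.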
