
Commit d8430b9

kv-cache : avoid modifying recurrent cells when setting inputs

- kv-cache : remove inp_s_mask. It was replaced with equivalent and simpler functionality using rs_z (the first zeroed state) and the already-existing inp_s_copy.
1 parent 2252eef commit d8430b9
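The gist of the change: instead of multiplying every recurrent state by an F32 mask (inp_s_mask) to clear sequences that start fresh, find_slot now records the first cell whose source is -1 as rs_z, the graph zeroes that single state once, and inp_s_copy redirects every fresh cell to read from it. A minimal standalone simulation of the new input computation (the cell/src0/rs_z names follow the diff below; fill_s_copy itself is illustrative):

#include <cstdint>
#include <vector>

struct cell { int32_t src0 = -1; }; // staged source id; -1 means "starts fresh"

// Simulate filling the s_copy input for n_kv cells starting at `head`,
// given rs_z, the first zeroed cell (mirrors the new set_input below).
static std::vector<int32_t> fill_s_copy(const std::vector<cell> & cells,
        uint32_t head, uint32_t n_kv, int32_t rs_z) {
    std::vector<int32_t> data(n_kv);
    for (uint32_t i = 0; i < n_kv; ++i) {
        const uint32_t cell_id = i + head;
        int32_t src = cells[cell_id].src0;
        if (src < 0) {
            src = rs_z;    // fresh sequences read from the single zeroed state
        }
        if ((uint32_t) src >= cells.size()) {
            src = cell_id; // ignore out-of-bound sources
        }
        data[i] = src;
    }
    return data;
}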

5 files changed (+81, -154 lines)


src/llama-graph.cpp

Lines changed: 36 additions & 48 deletions
@@ -242,23 +242,23 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_copy(i);
-        }
-    }
-}
+            const uint32_t cell_id = i + kv_self->head;
 
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
+            const auto & kv_cell = kv_self->cells[cell_id];
 
-    const int64_t n_kv = kv_self->n;
+            int32_t src = kv_cell.src0;
 
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
+            // prevent out-of-bound sources
+            if (src < 0) {
+                GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source
+                src = kv_self->rs_z;
+            }
+            if ((uint32_t) src >= kv_self->size) {
+                // ignore out-of-bound sources
+                src = cell_id;
+            }
 
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_self->s_mask(i);
+            data[i] = src;
         }
     }
 }
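For intuition, a quick usage check of the fill_s_copy sketch from the top of the page, matching the loop in this hunk:

// cells 4..6 staged with src0 = {2, -1, 6}, head = 4, n_kv = 3, rs_z = 5:
// the fresh cell 5 is redirected to the zeroed cell, the others keep their sources.
std::vector<cell> cells(8);
cells[4].src0 = 2; cells[5].src0 = -1; cells[6].src0 = 6;
const auto data = fill_s_copy(cells, /*head=*/4, /*n_kv=*/3, /*rs_z=*/5);
// data == {2, 5, 6}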
@@ -970,23 +970,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
-
-    const auto n_kv = kv_self->n;
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 

@@ -1439,43 +1422,48 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
             int32_t   n_state,
-            int32_t   n_seqs) const {
+            int32_t   n_seqs,
+               bool   avoid_copies) const {
     const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv    = kv_self->n;
     const auto kv_head = kv_self->head;
+    const auto rs_zero = kv_self->rs_z;
 
     ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size);
 
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx0, states, state_copy);
-
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx0, states, state_mask);
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
     // copy states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+            states_extra,
+            ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // this shrinks the tensors's ne[1] to n_kv
+        states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        // the part of the states that will be used and modified
+        states = ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    }
 
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const {
     const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
@@ -1486,8 +1474,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_self->k_l[il];
 
-    ggml_tensor * token_shift = build_copy_mask_state(
-            gf, token_shift_all, state_copy, state_mask,
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
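One detail worth noting in build_recurrent_state above: when no cell needs clearing, rs_z is -1, and the (rs_zero >= 0) factor collapses both the element count and the byte offset of the view to zero, so the ggml_scale_inplace becomes a no-op without any branch in the graph topology. A minimal sketch of that pattern in isolation (plain ggml; the helper name is illustrative):

#include "ggml.h"

// Zero out row `row` of a [n_state, n_cells] tensor, or do nothing when row < 0:
// a zero-sized view turns the scale into a no-op, keeping the graph shape fixed.
static void zero_row_maybe(struct ggml_context * ctx, struct ggml_cgraph * gf,
        struct ggml_tensor * states, int64_t n_state, int64_t row) {
    struct ggml_tensor * view = ggml_view_1d(ctx, states,
            n_state*(row >= 0),            // 0 elements when there is nothing to clear
            row*states->nb[1]*(row >= 0)); // offset clamped to 0 in that case
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx, view, 0));
}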

src/llama-graph.h

Lines changed: 3 additions & 17 deletions
@@ -198,18 +198,6 @@ class llm_graph_input_s_copy : public llm_graph_input_i {
     const llama_kv_cache_recurrent * kv_self;
 };
 
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_s_mask() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_mask; // F32 [1, n_kv]
-
-    const llama_kv_cache_recurrent * kv_self;
-};
-
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -519,7 +507,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
-    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -604,18 +591,17 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_copy_mask_state(
+    ggml_tensor * build_recurrent_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
             int32_t   n_state,
-            int32_t   n_seqs) const;
+            int32_t   n_seqs,
+               bool   avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const;
 
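avoid_copies defaults to false, so callers like build_rwkv_token_shift_load keep the old gather-and-shrink behavior. A hypothetical caller sketch of opting out when the caller gathers rows itself (ssm_states_all and the surrounding builder context are assumptions, not code from this commit):

// With avoid_copies == true the returned tensor still spans the whole cache
// (ne[1] == kv_self->size); the caller must pick its rows via state_copy.
ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy,
        hparams.n_embd_v_s(), n_seqs, /*avoid_copies=*/true);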

src/llama-kv-cache.cpp

Lines changed: 20 additions & 58 deletions
@@ -529,8 +529,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         head_cur = 0;
     }
 
-    // otherwise, one cell per token.
-
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
         return -1;
@@ -2310,21 +2308,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
 
     bool success = true;
 
-    // TODO: here we have to verify that all ubatches can fit in the cells
-    // however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
-    // during the compute of each ubatch. to reproduce, uncomment the following loop and run:
-    //
-    //   $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
-    //
-    // recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
-    //
-    GGML_UNUSED(ubatches);
-    //for (const auto & ubatch : ubatches) {
-    //    if (!find_slot(ubatch)) {
-    //        success = false;
-    //        break;
-    //    }
-    //}
+    for (const auto & ubatch : ubatches) {
+        if (!find_slot(ubatch)) {
+            success = false;
+            break;
+        }
+    }
 
     // restore the original state
     cells = std::move(org_cells);
@@ -2514,6 +2503,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         }
     }
 
+    // Find first to-be-cleared cell
+    rs_z = -1;
+    for (int i = min; i <= max; ++i) {
+        if (rs_z < 0 && cells[i].src == -1) {
+            rs_z = i;
+        }
+        // Stage the source ids for all used cells to allow correct seq_* behavior
+        // and still make these values available when setting the inputs
+        cells[i].src0 = cells[i].src;
+        cells[i].src  = i;
+    }
+
     // allow getting the range of used cells, from head to head + n
     head = min;
     n    = max - min + 1;
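The staging step above is what lets set_input stay truly const: src keeps the cells self-consistent for later cache bookkeeping, while src0 preserves where each state actually has to be copied from. A small standalone simulation of that loop (kv_cell trimmed to the two fields involved; names otherwise illustrative):

#include <cstdint>
#include <vector>

struct kv_cell_sketch {
    int32_t src  = -1; // where the state should be copied from
    int32_t src0 = -1; // staged copy of src, read later when setting inputs
};

// Mimic the staging in find_slot: remember the first cell that needs clearing,
// stage every source id into src0, then mark each cell as its own source so
// subsequent operations treat the copies as already done.
static int32_t stage_sources(std::vector<kv_cell_sketch> & cells, int min, int max) {
    int32_t rs_z = -1;
    for (int i = min; i <= max; ++i) {
        if (rs_z < 0 && cells[i].src == -1) {
            rs_z = i;
        }
        cells[i].src0 = cells[i].src;
        cells[i].src  = i;
    }
    return rs_z;
}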
@@ -2525,47 +2526,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 }
 
 bool llama_kv_cache_recurrent::get_can_shift() const {
-    return false;
-}
-
-int32_t llama_kv_cache_recurrent::s_copy(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    // prevent out-of-bound sources
-    if (cell.src < 0 || (uint32_t) cell.src >= size) {
-        cell.src = cell_id;
-    }
-
-    int32_t res = cell.src;
-
-    // TODO: do not mutate the KV cache
-    // ensure copy only happens once
-    if (cell.src != (int32_t) cell_id) {
-        cell.src = cell_id;
-    }
-
-    return res;
-}
-
-float llama_kv_cache_recurrent::s_mask(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    float res = (float) (cell.src >= 0);
-
-    // only clear once
-    if (cell.src < 0) {
-        cell.src = cell_id;
-    }
-
-    return res;
+    // shifting the pos is trivial for recurrent models
+    return true;
 }
 
 size_t llama_kv_cache_recurrent::total_size() const {

src/llama-kv-cache.h

Lines changed: 5 additions & 5 deletions
@@ -362,10 +362,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
 
     bool get_can_shift() const override;
 
-    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
-    int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
-
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -378,10 +374,14 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;
 
+    // first zero-ed state
+    int32_t rs_z = -1;
+
     // TODO: optimize for recurrent state needs
     struct kv_cell {
         llama_pos pos = -1;
-        int32_t   src = -1; // used to copy states
+        int32_t   src  = -1; // used to know where states should be copied from
+        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
         int32_t   tail = -1;
 
         std::set<llama_seq_id> seq_id;
