
kv-cache : avoid modifying recurrent cells when setting inputs #13834


Merged
merged 3 commits on Jun 10, 2025
83 changes: 30 additions & 53 deletions src/llama-graph.cpp
@@ -250,22 +250,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
}
}

void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);

const int64_t n_kv = kv_state->get_n_kv();

if (s_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
float * data = (float *) s_mask->data;

// clear unused states
for (int i = 0; i < n_kv; ++i) {
data[i] = kv_state->s_mask(i);
}
}
}

void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);

@@ -986,23 +970,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
return cur;
}

ggml_tensor * llm_graph_context::build_inp_s_mask() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);

auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);

const auto n_kv = kv_state->get_n_kv();

auto & cur = inp->s_mask;

cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
ggml_set_input(cur);

res->add_input(std::move(inp));

return cur;
}

ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);

@@ -1455,43 +1422,53 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}

ggml_tensor * llm_graph_context::build_copy_mask_state(
ggml_tensor * llm_graph_context::build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const {
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);

const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
const auto rs_zero = kv_state->get_rs_z();

ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());

// copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// this shrinks the tensors's ne[1] to n_kv
states = ggml_get_rows(ctx0, states, state_copy);
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

// clear states of sequences which are starting at the beginning of this batch
// FIXME: zero-out NANs?
states = ggml_mul(ctx0, states, state_mask);
ggml_tensor * output_states;

if (!avoid_copies) {
// copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// {state_size, kv_size} -> {state_size, n_seqs}
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
ggml_build_forward_expand(gf, output_states);
} else {
// FIXME: make the gathering operation happen before the copy below
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
output_states = states;
}

// copy states which won't be changed further (between n_seqs and n_kv)
// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)),
ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
states_extra,
ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));

// the part of the states that will be used and modified
return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
return output_states;
}

ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1502,8 +1479,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(

ggml_tensor * token_shift_all = kv_state->get_k_l(il);

ggml_tensor * token_shift = build_copy_mask_state(
gf, token_shift_all, state_copy, state_mask,
ggml_tensor * token_shift = build_recurrent_state(
gf, token_shift_all, state_copy,
hparams.n_embd_k_s(), n_seqs);

token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
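A note on the state-clearing trick in `build_recurrent_state` above: the `state_zero` view multiplies both its element count and its byte offset by `(rs_zero >= 0)`, so when no sequence needs a cleared state the view has zero elements and the in-place scale becomes a no-op. Below is a minimal plain-C++ sketch of that arithmetic, using a flat `std::vector` and element offsets in place of real ggml tensors and `nb[1]` byte strides; the buffer layout and values are made up for illustration.

```cpp
#include <cstdio>
#include <vector>

// Plain-C++ model of the "zero one state in-graph" trick from build_recurrent_state:
// ggml_view_1d(ctx, states, state_size*(rs_zero >= 0), rs_zero*nb1*(rs_zero >= 0))
// followed by ggml_scale_inplace(..., 0) is a no-op when rs_zero < 0,
// because both the view length and its offset collapse to 0.
static void clear_zero_state(std::vector<float> & states, int state_size, int rs_zero) {
    const int len    = state_size * (rs_zero >= 0);            // 0 elements when rs_zero < 0
    const int offset = rs_zero * state_size * (rs_zero >= 0);  // offset also forced to 0
    for (int i = 0; i < len; ++i) {
        states[offset + i] = 0.0f; // "scale by 0"
    }
}

int main() {
    const int state_size = 4;
    const int n_cells    = 3;
    std::vector<float> states(state_size * n_cells, 1.0f);

    clear_zero_state(states, state_size, /*rs_zero=*/-1); // no cell needs clearing -> no-op
    clear_zero_state(states, state_size, /*rs_zero=*/ 1); // clear only cell 1

    for (int c = 0; c < n_cells; ++c) {
        for (int i = 0; i < state_size; ++i) {
            printf("%.0f ", states[c*state_size + i]);
        }
        printf("\n"); // prints: 1 1 1 1 / 0 0 0 0 / 1 1 1 1
    }
    return 0;
}
```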
22 changes: 4 additions & 18 deletions src/llama-graph.h
@@ -200,18 +200,6 @@ class llm_graph_input_s_copy : public llm_graph_input_i {
const llama_kv_cache_recurrent_state * kv_state;
};

class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
virtual ~llm_graph_input_s_mask() = default;

void set_input(const llama_ubatch * ubatch) override;

ggml_tensor * s_mask; // F32 [1, n_kv]

const llama_kv_cache_recurrent_state * kv_state;
};

class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
llm_graph_input_cross_embd(
@@ -521,7 +509,6 @@ struct llm_graph_context {
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_mask() const;

ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -606,18 +593,17 @@ struct llm_graph_context {
// recurrent
//

ggml_tensor * build_copy_mask_state(
ggml_tensor * build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const;
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;

ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;

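For context on the header change above: `build_copy_mask_state` is renamed to `build_recurrent_state`, the `state_mask` argument is dropped, and `avoid_copies` lets a caller skip the gather of the first `n_seqs` states and apply `state_copy` itself later. The following toy sketch models that contract with plain vectors; the function body here is only illustrative and is not the ggml-graph implementation.

```cpp
#include <cstdio>
#include <vector>

// Toy model of the new build_recurrent_state contract: by default the first
// n_seqs states are gathered into output order via state_copy; with
// avoid_copies=true the caller gets the un-gathered buffer and is expected
// to apply state_copy itself later (e.g. fused into a subsequent op).
static std::vector<float> recurrent_state(const std::vector<float> & states,
                                          const std::vector<int>   & state_copy,
                                          int state_size, int n_seqs,
                                          bool avoid_copies = false) {
    if (avoid_copies) {
        return states; // un-gathered buffer (in the real graph this is the same tensor)
    }
    std::vector<float> out(static_cast<size_t>(state_size) * n_seqs);
    for (int s = 0; s < n_seqs; ++s) {
        const int src = state_copy[s]; // row to read this sequence's state from
        for (int i = 0; i < state_size; ++i) {
            out[s*state_size + i] = states[src*state_size + i];
        }
    }
    return out;
}

int main() {
    const int state_size = 2;
    // two cells with states [1,1] and [2,2]; both sequences continue from cell 1
    std::vector<float> states     = { 1, 1, 2, 2 };
    std::vector<int>   state_copy = { 1, 1 };

    auto gathered = recurrent_state(states, state_copy, state_size, /*n_seqs=*/2);
    printf("%.0f %.0f %.0f %.0f\n", gathered[0], gathered[1], gathered[2], gathered[3]); // 2 2 2 2
    return 0;
}
```

Existing call sites such as `build_rwkv_token_shift_load` compile unchanged because `avoid_copies` defaults to `false`.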
135 changes: 60 additions & 75 deletions src/llama-kv-cache-recurrent.cpp
@@ -406,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche

bool success = true;

// TODO: here we have to verify that all ubatches can fit in the cells
// however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
// during the compute of each ubatch. to reproduce, uncomment the following loop and run:
//
// $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
//
// recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
//
GGML_UNUSED(ubatches);
//for (const auto & ubatch : ubatches) {
// if (!find_slot(ubatch)) {
// success = false;
// break;
// }
//}
for (const auto & ubatch : ubatches) {
if (!find_slot(ubatch)) {
success = false;
break;
}
}

// restore the original state
cells = std::move(org_cells);
@@ -431,14 +422,13 @@
}

bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_tokens = ubatch.n_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
const uint32_t n_seqs = ubatch.n_seqs;

const uint32_t n_seq_tokens = ubatch.n_seq_tokens;

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head > used + 2*n_tokens) {
if (head > used + 2*n_seqs) {
head = 0;
}

@@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
empty_cell.src = orig_cell.src;
orig_cell.seq_id.erase(seq_id);
empty_cell.seq_id.insert(seq_id); // will be overwritten
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
}
seq_meta.tail = next_empty_cell;
// find next empty cell
if (s + 1 < n_seqs) {
next_empty_cell += 1;
for (uint32_t i = 0; i < size; ++i) {
next_empty_cell += 1;
if (next_empty_cell >= size) { next_empty_cell -= size; }
kv_cell & cell = cells[next_empty_cell];
if (cell.is_empty()) { break; }
next_empty_cell += 1;
}
}
}
@@ -553,8 +543,8 @@

// gather and re-order
for (uint32_t s = 0; s < n_seqs; ++s) {
int32_t dst_id = s + min;
int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
const int32_t dst_id = s + min;
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
if (dst_id != src_id) {
kv_cell & dst_cell = cells[dst_id];
kv_cell & src_cell = cells[src_id];
@@ -563,20 +553,22 @@
std::swap(dst_cell.src, src_cell.src);
std::swap(dst_cell.seq_id, src_cell.seq_id);

// swap tails (assuming they NEVER overlap)
for (const llama_seq_id seq_id : src_cell.seq_id) {
cells[seq_id].tail = src_id;
}
for (const llama_seq_id seq_id : dst_cell.seq_id) {
cells[seq_id].tail = dst_id;
// swap tails
for (uint32_t i = 0; i < size; ++i) {
int32_t & tail = cells[i].tail;
if (tail == src_id) {
tail = dst_id;
} else if (tail == dst_id) {
tail = src_id;
}
}
}
}

// update the pos of the used seqs
for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
int32_t cell_id = s + min;
const int32_t cell_id = s + min;
kv_cell & cell = cells[cell_id];

if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -594,6 +586,38 @@
}
}

// Find first cell without src refs, to use as the zero-ed state
{
// TODO: bake-in src refcounts in the cell metadata
std::vector<int32_t> refcounts(size, 0);
for (size_t i = 0; i < size; ++i) {
const int32_t src = cells[i].src;
if (src >= 0) {
refcounts[src] += 1;
}
}

rs_z = -1;
for (int i = min; i <= max; ++i) {
if (refcounts[i] == 0) {
rs_z = i;
break;
}
}

for (int i = min; i <= max; ++i) {
if (cells[i].src < 0) {
GGML_ASSERT(rs_z >= 0);
cells[i].src0 = rs_z;
} else {
// Stage the source ids for all used cells to allow correct seq_* behavior
// and still make these values available when setting the inputs
cells[i].src0 = cells[i].src;
}
cells[i].src = i; // avoid moving or clearing twice
}
}

// allow getting the range of used cells, from head to head + n
head = min;
n = max - min + 1;
@@ -605,47 +629,8 @@
}

bool llama_kv_cache_recurrent::get_can_shift() const {
return false;
}

int32_t llama_kv_cache_recurrent::s_copy(int i) const {
const uint32_t cell_id = i + head;

//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);

// prevent out-of-bound sources
if (cell.src < 0 || (uint32_t) cell.src >= size) {
cell.src = cell_id;
}

int32_t res = cell.src;

// TODO: do not mutate the KV cache
// ensure copy only happens once
if (cell.src != (int32_t) cell_id) {
cell.src = cell_id;
}

return res;
}

float llama_kv_cache_recurrent::s_mask(int i) const {
const uint32_t cell_id = i + head;

//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);

float res = (float) (cell.src >= 0);

// only clear once
if (cell.src < 0) {
cell.src = cell_id;
}

return res;
// shifting the pos is trivial for recurrent models
return true;
}

size_t llama_kv_cache_recurrent::total_size() const {
@@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
return is_full ? 0 : kv->head;
}

int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
return is_full ? 0 : kv->rs_z;
}

uint32_t llama_kv_cache_recurrent_state::get_size() const {
return kv->size;
}
@@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
}

int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
return kv->s_copy(i);
}

float llama_kv_cache_recurrent_state::s_mask(int i) const {
return kv->s_mask(i);
return kv->cells[i + kv->head].src0;
}
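Net effect of the changes in this file: `find_slot` now picks a zero-able cell (`rs_z`) by refcounting the `src` fields and stages each cell's copy source in `src0`, so `s_copy()` becomes a pure read and setting inputs no longer mutates the cells. A simplified, self-contained model of that flow is sketched below; the `kv_cell` here carries only `src`/`src0` (the real struct holds more metadata) and the head offset is omitted.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified model of the new flow: refcount src fields to pick rs_z,
// stage per-cell copy sources in src0, then read them without mutation.
struct kv_cell {
    int32_t src  = -1; // -1 means the cell has no state yet
    int32_t src0 = -1; // staged copy source, set once per find_slot
};

int main() {
    std::vector<kv_cell> cells(4);
    cells[0].src =  2;   // this sequence continues from cell 2
    cells[1].src = -1;   // this sequence starts fresh -> needs a zeroed state
    cells[2].src =  2;
    cells[3].src =  3;
    const int32_t min = 0, max = 3;

    // refcount the sources (as in find_slot)
    std::vector<int32_t> refcounts(cells.size(), 0);
    for (const auto & cell : cells) {
        if (cell.src >= 0) refcounts[cell.src] += 1;
    }

    // first unreferenced cell in [min, max] becomes the zero-ed state
    int32_t rs_z = -1;
    for (int32_t i = min; i <= max; ++i) {
        if (refcounts[i] == 0) { rs_z = i; break; }
    }

    // stage copy sources; cells that start fresh copy from the zero-ed state
    for (int32_t i = min; i <= max; ++i) {
        cells[i].src0 = cells[i].src < 0 ? rs_z : cells[i].src;
        cells[i].src  = i; // avoid moving or clearing twice
    }

    // s_copy() is now a pure read of the staged value
    for (int32_t i = min; i <= max; ++i) {
        printf("s_copy(%d) = %d\n", i, cells[i].src0); // prints 2, 0 (rs_z), 2, 3
    }
    return 0;
}
```

Reading the staged value any number of times leaves the cells untouched, which is what makes the dry-run `find_slot` loop in `prepare()` above safe.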