Skip to content

Commit bdbfb4e

Browse files
committed
kv-cache : fix non-consecutive token pos warning for recurrent models
The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies
1 parent d8430b9 commit bdbfb4e

File tree

2 files changed

+27
-38
lines changed

2 files changed

+27
-38
lines changed

src/llama-graph.cpp

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -242,23 +242,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
242242

243243
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
244244
for (uint32_t i = 0; i < n_kv; ++i) {
245-
const uint32_t cell_id = i + kv_self->head;
246-
247-
const auto & kv_cell = kv_self->cells[cell_id];
248-
249-
int32_t src = kv_cell.src0;
250-
251-
// prevent out-of-bound sources
252-
if (src < 0) {
253-
GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source
254-
src = kv_self->rs_z;
255-
}
256-
if ((uint32_t) src >= kv_self->size) {
257-
// ignore out-of-bound sources
258-
src = cell_id;
259-
}
260-
261-
data[i] = src;
245+
data[i] = kv_self->cells[i + kv_self->head].src0;
262246
}
263247
}
264248
}
@@ -1442,7 +1426,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
14421426
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
14431427
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
14441428

1445-
// copy states which won't be changed further (between n_seqs and n_kv)
1429+
// copy extra states which won't be changed further (between n_seqs and n_kv)
14461430
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
14471431
ggml_build_forward_expand(gf,
14481432
ggml_cpy(ctx0,
@@ -1452,10 +1436,8 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
14521436
if (!avoid_copies) {
14531437
// copy states
14541438
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1455-
// this shrinks the tensors's ne[1] to n_kv
1439+
// this shrinks the tensors's ne[1] to n_seqs
14561440
states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1457-
// the part of the states that will be used and modified
1458-
states = ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
14591441
}
14601442

14611443
return states;

src/llama-kv-cache.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,17 +2337,17 @@ void llama_kv_cache_recurrent::defrag_sched(float thold) {
23372337
void llama_kv_cache_recurrent::set_full() {
23382338
n = size;
23392339
head = 0;
2340+
rs_z = 0;
23402341
}
23412342

23422343
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
2343-
const uint32_t n_tokens = ubatch.n_tokens;
2344-
const uint32_t n_seqs = ubatch.n_seqs;
2344+
const uint32_t n_seqs = ubatch.n_seqs;
23452345

23462346
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
23472347

23482348
// if we have enough unused cells before the current head ->
23492349
// better to start searching from the beginning of the cache, hoping to fill it
2350-
if (head > used + 2*n_tokens) {
2350+
if (head > used + 2*n_seqs) {
23512351
head = 0;
23522352
}
23532353

@@ -2443,16 +2443,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
24432443
empty_cell.src = orig_cell.src;
24442444
orig_cell.seq_id.erase(seq_id);
24452445
empty_cell.seq_id.insert(seq_id); // will be overwritten
2446+
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
24462447
}
24472448
seq_meta.tail = next_empty_cell;
24482449
// find next empty cell
24492450
if (s + 1 < n_seqs) {
2450-
next_empty_cell += 1;
24512451
for (uint32_t i = 0; i < size; ++i) {
2452+
next_empty_cell += 1;
24522453
if (next_empty_cell >= size) { next_empty_cell -= size; }
24532454
kv_cell & cell = cells[next_empty_cell];
24542455
if (cell.is_empty()) { break; }
2455-
next_empty_cell += 1;
24562456
}
24572457
}
24582458
}
@@ -2472,12 +2472,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
24722472
std::swap(dst_cell.src, src_cell.src);
24732473
std::swap(dst_cell.seq_id, src_cell.seq_id);
24742474

2475-
// swap tails (assuming they NEVER overlap)
2476-
for (const llama_seq_id seq_id : src_cell.seq_id) {
2477-
cells[seq_id].tail = src_id;
2478-
}
2479-
for (const llama_seq_id seq_id : dst_cell.seq_id) {
2480-
cells[seq_id].tail = dst_id;
2475+
// swap tails
2476+
for (uint32_t i = 0; i < size; ++i) {
2477+
int32_t & tail = cells[i].tail;
2478+
if (tail == src_id) {
2479+
tail = dst_id;
2480+
} else if (tail == dst_id) {
2481+
tail = src_id;
2482+
}
24812483
}
24822484
}
24832485
}
@@ -2506,13 +2508,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
25062508
// Find first to-be-cleared cell
25072509
rs_z = -1;
25082510
for (int i = min; i <= max; ++i) {
2509-
if (rs_z < 0 && cells[i].src == -1) {
2510-
rs_z = i;
2511+
if (cells[i].src == -1) {
2512+
if (rs_z < 0) {
2513+
rs_z = i;
2514+
}
2515+
2516+
cells[i].src0 = rs_z;
2517+
} else {
2518+
// Stage the source ids for all used cells to allow correct seq_* behavior
2519+
// and still make these values available when setting the inputs
2520+
cells[i].src0 = cells[i].src;
25112521
}
2512-
// Stage the source ids for all used cells to allow correct seq_* behavior
2513-
// and still make these values available when setting the inputs
2514-
cells[i].src0 = cells[i].src;
2515-
cells[i].src = i;
2522+
cells[i].src = i; // avoid moving or clearing twice
25162523
}
25172524

25182525
// allow getting the range of used cells, from head to head + n

0 commit comments

Comments (0)