Skip to content

Commit 40cbf57

Browse files
authored
kv-cache : fix shift and defrag logic (#14081)
* kv-cache : fix shift (ggml-ci)
* cont : reset shift[i] (ggml-ci)
* cont : fix defrag erasing cells that didn't move (ggml-ci)
1 parent 7f4fbe5 commit 40cbf57

File tree

2 files changed: +12 / -9 lines

2 files changed: +12 / -9 lines

src/llama-kv-cache-unified.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
         for (uint32_t i = 0; i < n_kv; ++i) {
             assert(dinfo.ids[i] <= n_kv);

-            if (dinfo.ids[i] == n_kv) {
+            if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
                 continue;
             }

@@ -944,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

-    //GGML_ASSERT(kv_self->size == n_ctx);
-
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);

-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
     ggml_set_input(inp->k_shift);

     for (const auto & layer : layers) {

src/llama-kv-cells.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ class llama_kv_cells_unified {
         assert(isrc < pos.size());
         assert(idst < pos.size());

+        assert(pos[idst] == -1);
+        assert(pos[isrc] != -1);
+
         pos  [idst] = pos  [isrc];
         shift[idst] = shift[isrc];
         seq  [idst] = seq  [isrc];
@@ -144,9 +147,10 @@ class llama_kv_cells_unified {
         assert(pos[i] != -1);

         seq_pos_rm(i);
+        seq[i].reset();

         pos[i] = -1;
-        seq[i].reset();
+        shift[i] = 0;

         used.erase(i);
     }
@@ -164,6 +168,7 @@ class llama_kv_cells_unified {

         if (seq[i].none()) {
             pos[i] = -1;
+            shift[i] = 0;

             used.erase(i);

@@ -192,6 +197,7 @@ class llama_kv_cells_unified {
         seq[i].reset();

         pos[i] = -1;
+        shift[i] = 0;

         used.erase(i);

@@ -317,21 +323,20 @@ class llama_kv_cells_unified {
         pos[i]   += d;
         shift[i] += d;

-        seq_pos_add(i);
-
         has_shift = true;

         if (pos[i] < 0) {
-            seq_pos_rm(i);
-
             seq[i].reset();
             pos[i] = -1;
+            shift[i] = 0;

             used.erase(i);

             return true;
         }

+        seq_pos_add(i);
+
         return false;
     }


0 commit comments

Comments (0)