Skip to content

Commit 7b50d58

Browse files
authored
kv-cells : fix tracking of seq_pos (#14339)
* kv-cells : fix tracking of seq_pos during cache reuse ggml-ci * cont : improve error message ggml-ci * cont : add more comments
1 parent 3a9457d commit 7b50d58

File tree

5 files changed

+56
-17
lines changed

5 files changed

+56
-17
lines changed

include/llama.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -944,12 +944,14 @@ extern "C" {
944944
// Requires the context to have a memory.
945945
// For encode-decoder contexts, processes the batch using the decoder.
946946
// Positive return values does not mean a fatal error, but rather a warning.
947-
// Upon non-zero return values, the memory state is restored to the state before this call
947+
// Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
948+
// To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
949+
// Upon other return values, the memory state is restored to the state before this call
948950
// 0 - success
949951
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
950-
// 2 - aborted
952+
// 2 - aborted (processed ubatches will remain in the context's memory)
951953
// -1 - invalid input batch
952-
// < -1 - error
954+
// < -1 - fatal error (processed ubatches will remain in the context's memory)
953955
LLAMA_API int32_t llama_decode(
954956
struct llama_context * ctx,
955957
struct llama_batch batch);

src/llama-batch.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,21 +245,32 @@ bool llama_batch_allocr::init(
245245
}
246246

247247
if (memory) {
248+
bool ok = true;
249+
248250
if (batch.token) {
249251
if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
250-
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
251-
return false;
252+
ok = false;
252253
}
253254
} else {
254255
assert(batch.embd);
255256

256257
// for embeddings (typically used as vision input), we allow them to have repeating positions
257258
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
258259
if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
259-
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
260-
return false;
260+
ok = false;
261261
}
262262
}
263+
264+
if (!ok) {
265+
LLAMA_LOG_ERROR(
266+
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
267+
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
268+
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
269+
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
270+
__func__, s, s, memory->seq_pos_max(s), s, seq_pos_min(s));
271+
272+
return false;
273+
}
263274
}
264275

265276
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {

src/llama-context.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
10181018
pos_min[s] = std::numeric_limits<llama_pos>::max();
10191019
}
10201020

1021-
// TODO: fix sequence indexing
10221021
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
10231022
const auto & seq_id = ubatch.seq_id[i][0];
10241023

src/llama-kv-cells.h

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <cassert>
88
#include <vector>
99
#include <set>
10+
#include <map>
1011

1112
// meta information about KV cells that can be part of multiple sequences at the same time
1213
// TODO: add unit tests
@@ -164,7 +165,7 @@ class llama_kv_cells_unified {
164165
assert(seq_id >= 0);
165166

166167
seq[i].reset(seq_id);
167-
seq_pos[seq_id].erase(pos[i]);
168+
seq_pos_dec(seq_id, pos[i]);
168169

169170
if (seq[i].none()) {
170171
pos[i] = -1;
@@ -187,7 +188,7 @@ class llama_kv_cells_unified {
187188
seq[i].reset();
188189

189190
seq[i].set(seq_id);
190-
seq_pos[seq_id].insert(pos[i]);
191+
seq_pos_inc(seq_id, pos[i]);
191192

192193
return false;
193194
}
@@ -232,7 +233,7 @@ class llama_kv_cells_unified {
232233
assert(!seq[i].test(seq_id));
233234

234235
seq[i].set(seq_id);
235-
seq_pos[seq_id].insert(pos[i]);
236+
seq_pos_inc(seq_id, pos[i]);
236237
}
237238

238239
// return the sequence id of this cell
@@ -259,7 +260,9 @@ class llama_kv_cells_unified {
259260
return -1;
260261
}
261262

262-
return *seq_pos[seq_id].begin();
263+
assert(seq_pos[seq_id].begin()->second > 0);
264+
265+
return seq_pos[seq_id].begin()->first;
263266
}
264267

265268
// the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ class llama_kv_cells_unified {
272275
return -1;
273276
}
274277

275-
return *seq_pos[seq_id].rbegin();
278+
assert(seq_pos[seq_id].rbegin()->second > 0);
279+
280+
return seq_pos[seq_id].rbegin()->first;
276281
}
277282

278283
// note: call only if the cell is not empty
@@ -389,17 +394,36 @@ class llama_kv_cells_unified {
389394
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
390395
std::vector<seq_set_t> seq;
391396

392-
// the set seq_pos[s] tells us which positions are currently present for sequence s
397+
// the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
398+
// if the position p is not present, seq_pos[s][p] is not set
393399
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
394-
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
400+
//
401+
// note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
402+
// - during performing a cache reuse via (rm + add)
403+
// - some vision models have input embeddings with repeating positions
404+
//
405+
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
395406

396407
// helper functions for updating `seq_pos`, once cell at a time:
397408

409+
void seq_pos_dec(llama_seq_id s, llama_pos p) {
410+
auto it = seq_pos[s].find(p);
411+
assert(it != seq_pos[s].end());
412+
413+
if (--it->second == 0) {
414+
seq_pos[s].erase(it);
415+
}
416+
}
417+
418+
void seq_pos_inc(llama_seq_id s, llama_pos p) {
419+
seq_pos[s][p]++;
420+
}
421+
398422
// remove cell i
399423
void seq_pos_rm(uint32_t i) {
400424
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
401425
if (seq[i].test(s)) {
402-
seq_pos[s].erase(pos[i]);
426+
seq_pos_dec(s, pos[i]);
403427
}
404428
}
405429
}
@@ -408,7 +432,7 @@ class llama_kv_cells_unified {
408432
void seq_pos_add(uint32_t i) {
409433
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
410434
if (seq[i].test(s)) {
411-
seq_pos[s].insert(pos[i]);
435+
seq_pos_inc(s, pos[i]);
412436
}
413437
}
414438
}

tools/server/server.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3418,9 +3418,12 @@ struct server_context {
34183418
}
34193419

34203420
if (ret < -1) {
3421+
// TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
34213422
err = "Compute error.";
34223423
}
34233424

3425+
// TODO: handle ret == 2 (abort) when we start aborting
3426+
34243427
if (!err.empty()) {
34253428
SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
34263429
for (auto & slot : slots) {

0 commit comments

Comments
 (0)