
Commit 0026c81

llama: restore a kv_cache in case of failed computation
1 parent acb9528 commit 0026c81

File tree

1 file changed: +65 −12 lines


src/llama.cpp

Lines changed: 65 additions & 12 deletions
@@ -2811,6 +2811,22 @@ struct llama_kv_cache {
     }
 };
 
+// saves the kv_cache state for future recovery
+// used to preserve the kv_cache state before searching for a slot
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t size = 0;
+        uint32_t used = 0;
+        uint32_t n    = 0;
+    } old_state;
+
+    std::vector<llama_kv_cell> recurrent_cells;    // for recurrent models only
+    std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
+
+    bool restore = false;
+};
+
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<ggml_context_ptr> ctxs;
@@ -3508,11 +3524,19 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
+       const struct llama_ubatch & batch,
+       struct llama_kv_slot_restorer * slot_restorer = nullptr) {
     const uint32_t n_tokens = batch.n_tokens;
     const uint32_t n_seqs = batch.n_seqs;
     const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
+    if (slot_restorer != nullptr) {
+        slot_restorer->old_state.head = cache.head;
+        slot_restorer->old_state.size = cache.size;
+        slot_restorer->old_state.used = cache.used;
+        slot_restorer->old_state.n    = cache.n;
+    }
+
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
@@ -3521,6 +3545,11 @@ static bool llama_kv_cache_find_slot(
         // can only process batches with an equal number of new tokens in each sequence
         GGML_ASSERT(batch.equal_seqs);
 
+        if (slot_restorer != nullptr) {
+            slot_restorer->recurrent_cells = cache.cells;
+            slot_restorer->restore = true;
+        }
+
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
@@ -3709,6 +3738,11 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
+    if (slot_restorer != nullptr) {
+        slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
+        slot_restorer->restore = true;
+    }
+
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
@@ -3998,6 +4032,23 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
     return cparams.flash_attn ? 256u : 32u;
 }
 
+static void llama_kv_cache_slot_restore(
+        const struct llama_kv_slot_restorer & restorer,
+              struct llama_kv_cache & cache) {
+    if (restorer.restore) {
+        cache.head = restorer.old_state.head;
+        cache.size = restorer.old_state.size;
+        cache.used = restorer.old_state.used;
+        cache.n    = restorer.old_state.n;
+
+        if (cache.recurrent) {
+            cache.cells = restorer.recurrent_cells;
+        } else {
+            llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
+        }
+    }
+}
+
 //
 // model loading and saving
 //
@@ -17256,6 +17307,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer;
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17340,7 +17392,7 @@ static int llama_decode_internal(
             kv_self.head = 0;
         }
 
-        if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+        if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
             return 1;
         }
 
@@ -17390,16 +17442,17 @@ static int llama_decode_internal(
         llama_set_inputs(lctx, ubatch);
 
         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-        switch (compute_status) {
-            case GGML_STATUS_SUCCESS:
-                break;
-            case GGML_STATUS_ABORTED:
-                return 2;
-            case GGML_STATUS_ALLOC_FAILED:
-                return -2;
-            case GGML_STATUS_FAILED:
-            default:
-                return -3;
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
         }
 
         // update the kv ring buffer
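
The change follows a simple save-then-rollback pattern: llama_kv_cache_find_slot snapshots the cache bookkeeping (head, size, used, n) before it mutates anything, and llama_decode_internal writes the snapshot back when the graph computation fails, restoring the full cell vector for recurrent models or removing the just-allocated slot range for non-recurrent ones. Below is a minimal, self-contained sketch of that pattern; the types and names (kv_cache, kv_cell, save_slot_state, restore_slot_state) are simplified stand-ins for illustration, not llama.cpp APIs.

// Minimal sketch of the save-then-rollback pattern added by this commit.
// All names here are simplified stand-ins, not llama.cpp APIs.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct kv_cell {
    int32_t pos = -1; // -1 means the cell is empty
};

struct kv_cache {
    uint32_t head = 0;
    uint32_t size = 8;
    uint32_t used = 0;
    uint32_t n    = 0;
    bool     recurrent = false;
    std::vector<kv_cell> cells = std::vector<kv_cell>(8);
};

// analogous to llama_kv_slot_restorer: a snapshot taken before slot search
struct kv_slot_restorer {
    struct {
        uint32_t head = 0;
        uint32_t size = 0;
        uint32_t used = 0;
        uint32_t n    = 0;
    } old_state;
    std::vector<kv_cell> recurrent_cells;                 // recurrent models: full copy
    std::pair<uint32_t, uint32_t> slot_boundaries{0, 0};  // non-recurrent: [first, second)
    bool restore = false;
};

// snapshot the bookkeeping before the cache is mutated
void save_slot_state(const kv_cache & cache, kv_slot_restorer & r) {
    r.old_state.head = cache.head;
    r.old_state.size = cache.size;
    r.old_state.used = cache.used;
    r.old_state.n    = cache.n;
    if (cache.recurrent) {
        r.recurrent_cells = cache.cells;
    }
    r.restore = true;
}

// roll the cache back after a failed computation
void restore_slot_state(const kv_slot_restorer & r, kv_cache & cache) {
    if (!r.restore) {
        return;
    }
    cache.head = r.old_state.head;
    cache.size = r.old_state.size;
    cache.used = r.old_state.used;
    cache.n    = r.old_state.n;
    if (cache.recurrent) {
        cache.cells = r.recurrent_cells;
    } else {
        // the real code calls llama_kv_cache_seq_rm over the slot boundaries;
        // here we simply clear the cells of the failed slot
        for (uint32_t i = r.slot_boundaries.first; i < r.slot_boundaries.second; ++i) {
            cache.cells[i] = kv_cell{};
        }
    }
}

int main() {
    kv_cache cache;
    kv_slot_restorer restorer;

    save_slot_state(cache, restorer);

    // pretend a slot of 4 tokens was allocated and the graph compute then failed
    restorer.slot_boundaries = {cache.head, cache.head + 4};
    for (uint32_t i = 0; i < 4; ++i) {
        cache.cells[i].pos = (int32_t) i;
    }
    cache.used += 4;

    restore_slot_state(restorer, cache);

    std::cout << "used after rollback: " << cache.used << "\n"; // prints 0
    return 0;
}

The asymmetry in the restore path mirrors the diff above: for recurrent architectures each cell stores the state of a whole sequence, which presumably keeps the cell vector small enough that copying it up front is acceptable, whereas a non-recurrent cache only remembers the slot boundaries and removes that range on failure (via llama_kv_cache_seq_rm with seq_id = -1).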
