Skip to content

Commit 4701893

Browse files
committed
llama: reverting kv_cache in case of failed compute
1 parent 5e354e3 commit 4701893

File tree

1 file changed

+49
-10
lines changed

1 file changed

+49
-10
lines changed

src/llama.cpp

Lines changed: 49 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2811,6 +2811,42 @@ struct llama_kv_cache {
28112811
}
28122812
};
28132813

2814+
class llama_kv_cache_state {
2815+
struct llama_kv_cache_state_short {
2816+
uint32_t head = 0;
2817+
uint32_t size = 0;
2818+
uint32_t used = 0;
2819+
uint32_t n = 0;
2820+
2821+
std::vector<llama_kv_cell> cells;
2822+
} old_state;
2823+
2824+
bool saved = false;
2825+
2826+
public:
2827+
void save_state(const llama_kv_cache& cache) {
2828+
old_state.head = cache.head;
2829+
old_state.size = cache.size;
2830+
old_state.used = cache.used;
2831+
old_state.n = cache.n;
2832+
old_state.cells = cache.cells;
2833+
2834+
saved = true;
2835+
}
2836+
2837+
void restore(llama_kv_cache& cache) {
2838+
if (saved) {
2839+
cache.head = old_state.head;
2840+
cache.size = old_state.size;
2841+
cache.used = old_state.used;
2842+
cache.n = old_state.n;
2843+
cache.cells = std::move(old_state.cells);
2844+
2845+
saved = false;
2846+
}
2847+
}
2848+
};
2849+
28142850
struct llama_control_vector {
28152851
std::vector<struct ggml_tensor *> tensors; // per layer
28162852
std::vector<ggml_context_ptr> ctxs;
@@ -17256,6 +17292,7 @@ static int llama_decode_internal(
1725617292
lctx.n_queued_tokens += n_tokens_all;
1725717293

1725817294
auto & kv_self = lctx.kv_self;
17295+
llama_kv_cache_state kv_cache_state_holder;
1725917296

1726017297
const int64_t n_embd = hparams.n_embd;
1726117298
const int64_t n_vocab = hparams.n_vocab;
@@ -17333,6 +17370,7 @@ static int llama_decode_internal(
1733317370
// non-causal masks do not use the KV cache
1733417371
if (hparams.causal_attn) {
1733517372
llama_kv_cache_update(&lctx);
17373+
kv_cache_state_holder.save_state(kv_self);
1733617374

1733717375
// if we have enough unused cells before the current head ->
1733817376
// better to start searching from the beginning of the cache, hoping to fill it
@@ -17390,16 +17428,17 @@ static int llama_decode_internal(
1739017428
llama_set_inputs(lctx, ubatch);
1739117429

1739217430
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17393-
switch (compute_status) {
17394-
case GGML_STATUS_SUCCESS:
17395-
break;
17396-
case GGML_STATUS_ABORTED:
17397-
return 2;
17398-
case GGML_STATUS_ALLOC_FAILED:
17399-
return -2;
17400-
case GGML_STATUS_FAILED:
17401-
default:
17402-
return -3;
17431+
if (compute_status != GGML_STATUS_SUCCESS) {
17432+
kv_cache_state_holder.restore(kv_self);
17433+
switch (compute_status) {
17434+
case GGML_STATUS_ABORTED:
17435+
return 2;
17436+
case GGML_STATUS_ALLOC_FAILED:
17437+
return -2;
17438+
case GGML_STATUS_FAILED:
17439+
default:
17440+
return -3;
17441+
}
1740317442
}
1740417443

1740517444
// update the kv ring buffer

0 commit comments

Comments
 (0)