Skip to content

Commit 20510ea

Browse files
committed
llama: reverting kv_cache in case of failed compute
1 parent 326e4d9 commit 20510ea

File tree

1 file changed

+49
-10
lines changed

1 file changed

+49
-10
lines changed

src/llama.cpp

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2806,6 +2806,42 @@ struct llama_kv_cache {
28062806
}
28072807
};
28082808

2809+
class llama_kv_cache_state {
2810+
struct llama_kv_cache_state_short {
2811+
uint32_t head = 0;
2812+
uint32_t size = 0;
2813+
uint32_t used = 0;
2814+
uint32_t n = 0;
2815+
2816+
std::vector<llama_kv_cell> cells;
2817+
} old_state;
2818+
2819+
bool saved = false;
2820+
2821+
public:
2822+
void save_state(const llama_kv_cache& cache) {
2823+
old_state.head = cache.head;
2824+
old_state.size = cache.size;
2825+
old_state.used = cache.used;
2826+
old_state.n = cache.n;
2827+
old_state.cells = cache.cells;
2828+
2829+
saved = true;
2830+
}
2831+
2832+
void restore(llama_kv_cache& cache) {
2833+
if (saved) {
2834+
cache.head = old_state.head;
2835+
cache.size = old_state.size;
2836+
cache.used = old_state.used;
2837+
cache.n = old_state.n;
2838+
cache.cells = std::move(old_state.cells);
2839+
2840+
saved = false;
2841+
}
2842+
}
2843+
};
2844+
28092845
struct llama_control_vector {
28102846
std::vector<struct ggml_tensor *> tensors; // per layer
28112847
std::vector<struct ggml_context *> ctxs;
@@ -16687,6 +16723,7 @@ static int llama_decode_internal(
1668716723
lctx.n_queued_tokens += n_tokens_all;
1668816724

1668916725
auto & kv_self = lctx.kv_self;
16726+
llama_kv_cache_state kv_cache_state_holder;
1669016727

1669116728
const int64_t n_embd = hparams.n_embd;
1669216729
const int64_t n_vocab = hparams.n_vocab;
@@ -16764,6 +16801,7 @@ static int llama_decode_internal(
1676416801
// non-causal masks do not use the KV cache
1676516802
if (hparams.causal_attn) {
1676616803
llama_kv_cache_update(&lctx);
16804+
kv_cache_state_holder.save_state(kv_self);
1676716805

1676816806
// if we have enough unused cells before the current head ->
1676916807
// better to start searching from the beginning of the cache, hoping to fill it
@@ -16821,16 +16859,17 @@ static int llama_decode_internal(
1682116859
llama_set_inputs(lctx, ubatch);
1682216860

1682316861
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
16824-
switch (compute_status) {
16825-
case GGML_STATUS_SUCCESS:
16826-
break;
16827-
case GGML_STATUS_ABORTED:
16828-
return 2;
16829-
case GGML_STATUS_ALLOC_FAILED:
16830-
return -2;
16831-
case GGML_STATUS_FAILED:
16832-
default:
16833-
return -3;
16862+
if (compute_status != GGML_STATUS_SUCCESS) {
16863+
kv_cache_state_holder.restore(kv_self);
16864+
switch (compute_status) {
16865+
case GGML_STATUS_ABORTED:
16866+
return 2;
16867+
case GGML_STATUS_ALLOC_FAILED:
16868+
return -2;
16869+
case GGML_STATUS_FAILED:
16870+
default:
16871+
return -3;
16872+
}
1683416873
}
1683516874

1683616875
// update the kv ring buffer

0 commit comments

Comments
 (0)