@@ -2811,6 +2811,42 @@ struct llama_kv_cache {
2811
2811
}
2812
2812
};
2813
2813
2814
+ class llama_kv_cache_state {
2815
+ struct llama_kv_cache_state_short {
2816
+ uint32_t head = 0;
2817
+ uint32_t size = 0;
2818
+ uint32_t used = 0;
2819
+ uint32_t n = 0;
2820
+
2821
+ std::vector<llama_kv_cell> cells;
2822
+ } old_state;
2823
+
2824
+ bool saved = false;
2825
+
2826
+ public:
2827
+ void save_state(const llama_kv_cache& cache) {
2828
+ old_state.head = cache.head;
2829
+ old_state.size = cache.size;
2830
+ old_state.used = cache.used;
2831
+ old_state.n = cache.n;
2832
+ old_state.cells = cache.cells;
2833
+
2834
+ saved = true;
2835
+ }
2836
+
2837
+ void restore(llama_kv_cache& cache) {
2838
+ if (saved) {
2839
+ cache.head = old_state.head;
2840
+ cache.size = old_state.size;
2841
+ cache.used = old_state.used;
2842
+ cache.n = old_state.n;
2843
+ cache.cells = std::move(old_state.cells);
2844
+
2845
+ saved = false;
2846
+ }
2847
+ }
2848
+ };
2849
+
2814
2850
struct llama_control_vector {
2815
2851
std::vector<struct ggml_tensor *> tensors; // per layer
2816
2852
std::vector<ggml_context_ptr> ctxs;
@@ -17256,6 +17292,7 @@ static int llama_decode_internal(
17256
17292
lctx.n_queued_tokens += n_tokens_all;
17257
17293
17258
17294
auto & kv_self = lctx.kv_self;
17295
+ llama_kv_cache_state kv_cache_state_holder;
17259
17296
17260
17297
const int64_t n_embd = hparams.n_embd;
17261
17298
const int64_t n_vocab = hparams.n_vocab;
@@ -17333,6 +17370,7 @@ static int llama_decode_internal(
17333
17370
// non-causal masks do not use the KV cache
17334
17371
if (hparams.causal_attn) {
17335
17372
llama_kv_cache_update(&lctx);
17373
+ kv_cache_state_holder.save_state(kv_self);
17336
17374
17337
17375
// if we have enough unused cells before the current head ->
17338
17376
// better to start searching from the beginning of the cache, hoping to fill it
@@ -17390,16 +17428,17 @@ static int llama_decode_internal(
17390
17428
llama_set_inputs(lctx, ubatch);
17391
17429
17392
17430
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17393
- switch (compute_status) {
17394
- case GGML_STATUS_SUCCESS:
17395
- break;
17396
- case GGML_STATUS_ABORTED:
17397
- return 2;
17398
- case GGML_STATUS_ALLOC_FAILED:
17399
- return -2;
17400
- case GGML_STATUS_FAILED:
17401
- default:
17402
- return -3;
17431
+ if (compute_status != GGML_STATUS_SUCCESS) {
17432
+ kv_cache_state_holder.restore(kv_self);
17433
+ switch (compute_status) {
17434
+ case GGML_STATUS_ABORTED:
17435
+ return 2;
17436
+ case GGML_STATUS_ALLOC_FAILED:
17437
+ return -2;
17438
+ case GGML_STATUS_FAILED:
17439
+ default:
17440
+ return -3;
17441
+ }
17403
17442
}
17404
17443
17405
17444
// update the kv ring buffer
0 commit comments