@@ -2806,6 +2806,42 @@ struct llama_kv_cache {
2806
2806
}
2807
2807
};
2808
2808
2809
+ class llama_kv_cache_state {
2810
+ struct llama_kv_cache_state_short {
2811
+ uint32_t head = 0;
2812
+ uint32_t size = 0;
2813
+ uint32_t used = 0;
2814
+ uint32_t n = 0;
2815
+
2816
+ std::vector<llama_kv_cell> cells;
2817
+ } old_state;
2818
+
2819
+ bool saved = false;
2820
+
2821
+ public:
2822
+ void save_state(const llama_kv_cache& cache) {
2823
+ old_state.head = cache.head;
2824
+ old_state.size = cache.size;
2825
+ old_state.used = cache.used;
2826
+ old_state.n = cache.n;
2827
+ old_state.cells = cache.cells;
2828
+
2829
+ saved = true;
2830
+ }
2831
+
2832
+ void restore(llama_kv_cache& cache) {
2833
+ if (saved) {
2834
+ cache.head = old_state.head;
2835
+ cache.size = old_state.size;
2836
+ cache.used = old_state.used;
2837
+ cache.n = old_state.n;
2838
+ cache.cells = std::move(old_state.cells);
2839
+
2840
+ saved = false;
2841
+ }
2842
+ }
2843
+ };
2844
+
2809
2845
struct llama_control_vector {
2810
2846
std::vector<struct ggml_tensor *> tensors; // per layer
2811
2847
std::vector<struct ggml_context *> ctxs;
@@ -16687,6 +16723,7 @@ static int llama_decode_internal(
16687
16723
lctx.n_queued_tokens += n_tokens_all;
16688
16724
16689
16725
auto & kv_self = lctx.kv_self;
16726
+ llama_kv_cache_state kv_cache_state_holder;
16690
16727
16691
16728
const int64_t n_embd = hparams.n_embd;
16692
16729
const int64_t n_vocab = hparams.n_vocab;
@@ -16764,6 +16801,7 @@ static int llama_decode_internal(
16764
16801
// non-causal masks do not use the KV cache
16765
16802
if (hparams.causal_attn) {
16766
16803
llama_kv_cache_update(&lctx);
16804
+ kv_cache_state_holder.save_state(kv_self);
16767
16805
16768
16806
// if we have enough unused cells before the current head ->
16769
16807
// better to start searching from the beginning of the cache, hoping to fill it
@@ -16821,16 +16859,17 @@ static int llama_decode_internal(
16821
16859
llama_set_inputs(lctx, ubatch);
16822
16860
16823
16861
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
16824
- switch (compute_status) {
16825
- case GGML_STATUS_SUCCESS:
16826
- break;
16827
- case GGML_STATUS_ABORTED:
16828
- return 2;
16829
- case GGML_STATUS_ALLOC_FAILED:
16830
- return -2;
16831
- case GGML_STATUS_FAILED:
16832
- default:
16833
- return -3;
16862
+ if (compute_status != GGML_STATUS_SUCCESS) {
16863
+ kv_cache_state_holder.restore(kv_self);
16864
+ switch (compute_status) {
16865
+ case GGML_STATUS_ABORTED:
16866
+ return 2;
16867
+ case GGML_STATUS_ALLOC_FAILED:
16868
+ return -2;
16869
+ case GGML_STATUS_FAILED:
16870
+ default:
16871
+ return -3;
16872
+ }
16834
16873
}
16835
16874
16836
16875
// update the kv ring buffer
0 commit comments