@@ -2811,6 +2811,22 @@ struct llama_kv_cache {
2811
2811
}
2812
2812
};
2813
2813
2814
+ // saves the kv_cache state for future recovery
2815
+ // used to preserve the kv_cache state before searching for a slot
2816
+ struct llama_kv_slot_restorer {
2817
+ struct llama_kv_cache_state {
2818
+ uint32_t head = 0;
2819
+ uint32_t size = 0;
2820
+ uint32_t used = 0;
2821
+ uint32_t n = 0;
2822
+ } old_state;
2823
+
2824
+ std::vector<llama_kv_cell> recurrent_cells; // for recurrent models only
2825
+ std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
2826
+
2827
+ bool restore = false;
2828
+ };
2829
+
2814
2830
struct llama_control_vector {
2815
2831
std::vector<struct ggml_tensor *> tensors; // per layer
2816
2832
std::vector<ggml_context_ptr> ctxs;
@@ -3508,11 +3524,19 @@ static bool llama_kv_cache_init(
3508
3524
// to the first cell of the slot.
3509
3525
static bool llama_kv_cache_find_slot(
3510
3526
struct llama_kv_cache & cache,
3511
- const struct llama_ubatch & batch) {
3527
+ const struct llama_ubatch & batch,
3528
+ struct llama_kv_slot_restorer * slot_restorer = nullptr) {
3512
3529
const uint32_t n_tokens = batch.n_tokens;
3513
3530
const uint32_t n_seqs = batch.n_seqs;
3514
3531
const uint32_t n_seq_tokens = batch.n_seq_tokens;
3515
3532
3533
+ if (slot_restorer != nullptr) {
3534
+ slot_restorer->old_state.head = cache.head;
3535
+ slot_restorer->old_state.size = cache.size;
3536
+ slot_restorer->old_state.used = cache.used;
3537
+ slot_restorer->old_state.n = cache.n;
3538
+ }
3539
+
3516
3540
if (cache.recurrent) {
3517
3541
// For recurrent state architectures (like Mamba or RWKV),
3518
3542
// each cache cell can store the state for a whole sequence.
@@ -3521,6 +3545,11 @@ static bool llama_kv_cache_find_slot(
3521
3545
// can only process batches with an equal number of new tokens in each sequence
3522
3546
GGML_ASSERT(batch.equal_seqs);
3523
3547
3548
+ if (slot_restorer != nullptr) {
3549
+ slot_restorer->recurrent_cells = cache.cells;
3550
+ slot_restorer->restore = true;
3551
+ }
3552
+
3524
3553
int32_t min = cache.size - 1;
3525
3554
int32_t max = 0;
3526
3555
@@ -3709,6 +3738,11 @@ static bool llama_kv_cache_find_slot(
3709
3738
}
3710
3739
}
3711
3740
3741
+ if (slot_restorer != nullptr) {
3742
+ slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
3743
+ slot_restorer->restore = true;
3744
+ }
3745
+
3712
3746
for (uint32_t s = 0; s < n_seqs; s++) {
3713
3747
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
3714
3748
uint32_t k = s*n_seq_tokens + i;
@@ -3998,6 +4032,23 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
3998
4032
return cparams.flash_attn ? 256u : 32u;
3999
4033
}
4000
4034
4035
+ static void llama_kv_cache_slot_restore(
4036
+ const struct llama_kv_slot_restorer & restorer,
4037
+ struct llama_kv_cache & cache) {
4038
+ if (restorer.restore) {
4039
+ cache.head = restorer.old_state.head;
4040
+ cache.size = restorer.old_state.size;
4041
+ cache.used = restorer.old_state.used;
4042
+ cache.n = restorer.old_state.n;
4043
+
4044
+ if (cache.recurrent) {
4045
+ cache.cells = restorer.recurrent_cells;
4046
+ } else {
4047
+ llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
4048
+ }
4049
+ }
4050
+ }
4051
+
4001
4052
//
4002
4053
// model loading and saving
4003
4054
//
@@ -17256,6 +17307,7 @@ static int llama_decode_internal(
17256
17307
lctx.n_queued_tokens += n_tokens_all;
17257
17308
17258
17309
auto & kv_self = lctx.kv_self;
17310
+ llama_kv_slot_restorer kv_slot_restorer;
17259
17311
17260
17312
const int64_t n_embd = hparams.n_embd;
17261
17313
const int64_t n_vocab = hparams.n_vocab;
@@ -17340,7 +17392,7 @@ static int llama_decode_internal(
17340
17392
kv_self.head = 0;
17341
17393
}
17342
17394
17343
- if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
17395
+ if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer )) {
17344
17396
return 1;
17345
17397
}
17346
17398
@@ -17390,16 +17442,17 @@ static int llama_decode_internal(
17390
17442
llama_set_inputs(lctx, ubatch);
17391
17443
17392
17444
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17393
- switch (compute_status) {
17394
- case GGML_STATUS_SUCCESS:
17395
- break;
17396
- case GGML_STATUS_ABORTED:
17397
- return 2;
17398
- case GGML_STATUS_ALLOC_FAILED:
17399
- return -2;
17400
- case GGML_STATUS_FAILED:
17401
- default:
17402
- return -3;
17445
+ if (compute_status != GGML_STATUS_SUCCESS) {
17446
+ llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
17447
+ switch (compute_status) {
17448
+ case GGML_STATUS_ABORTED:
17449
+ return 2;
17450
+ case GGML_STATUS_ALLOC_FAILED:
17451
+ return -2;
17452
+ case GGML_STATUS_FAILED:
17453
+ default:
17454
+ return -3;
17455
+ }
17403
17456
}
17404
17457
17405
17458
// update the kv ring buffer
0 commit comments