@@ -2815,22 +2815,6 @@ struct llama_kv_cache {
     }
 };
 
-// saves the kv_cache state for future recovery
-// used to preserve the kv_cache state before searching for a slot
-struct llama_kv_slot_restorer {
-    struct llama_kv_cache_state {
-        uint32_t head = 0;
-        uint32_t size = 0;
-        uint32_t used = 0;
-        uint32_t n    = 0;
-    } old_state;
-
-    std::vector<llama_kv_cell> recurrent_cells;    // for recurrent models only
-    std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
-
-    bool restore = false;
-};
-
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<struct ggml_context *> ctxs;
@@ -3666,21 +3650,24 @@ static bool llama_kv_cache_init(
 // updates the cache head
 // Note: On success, it's important that cache.head points
 // to the first cell of the slot.
-static bool llama_kv_cache_find_slot(
+struct llama_kv_cache_slot_info {
+    std::pair<uint32_t, uint32_t> boundaries; // slot cell range [begin, end)
+    bool found = false;                       // whether a slot was found
+
+    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
+    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
+
+    operator bool() const { return found; }
+};
+static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
+
+static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch,
-       struct llama_kv_slot_restorer * slot_restorer = nullptr) {
+       const struct llama_ubatch & batch) {
     const uint32_t n_tokens = batch.n_tokens;
     const uint32_t n_seqs = batch.n_seqs;
     const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
-    if (slot_restorer != nullptr) {
-        slot_restorer->old_state.head = cache.head;
-        slot_restorer->old_state.size = cache.size;
-        slot_restorer->old_state.used = cache.used;
-        slot_restorer->old_state.n    = cache.n;
-    }
-
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
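Aside: the value of `llama_kv_cache_slot_info` over the old plain `bool` is that failure still collapses to `false` through the `operator bool()` conversion, while success carries the claimed cell range. A minimal compilable sketch of that calling pattern (standalone; the `main` driver and the pretend slot values are illustrative, not part of the patch):

    #include <cstdint>
    #include <cstdio>
    #include <utility>

    // trimmed copy of the struct added in the hunk above
    struct llama_kv_cache_slot_info {
        std::pair<uint32_t, uint32_t> boundaries; // [begin, end) cell range
        bool found = false;

        explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
        llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}

        operator bool() const { return found; }
    };

    int main() {
        const auto slot = llama_kv_cache_slot_info(8, 24); // pretend find_slot claimed cells [8, 24)
        if (!slot) {
            return 1; // same early-out shape as the llama_decode_internal call site below
        }
        std::printf("slot cells: [%u, %u)\n", slot.boundaries.first, slot.boundaries.second);
        return 0;
    }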
@@ -3689,11 +3676,6 @@ static bool llama_kv_cache_find_slot(
         // can only process batches with an equal number of new tokens in each sequence
         GGML_ASSERT(batch.equal_seqs);
 
-        if (slot_restorer != nullptr) {
-            slot_restorer->recurrent_cells = cache.cells;
-            slot_restorer->restore = true;
-        }
-
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
@@ -3707,7 +3689,7 @@ static bool llama_kv_cache_find_slot(
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
                 LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                return false;
+                return llama_kv_cache_slot_info_failed;
             }
             if (j > 0) {
                 llama_kv_cell & seq = cache.cells[seq_id];
@@ -3842,15 +3824,17 @@ static bool llama_kv_cache_find_slot(
         // allow getting the range of used cells, from head to head + n
         cache.head = min;
         cache.n    = max - min + 1;
+        cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
+            [](const llama_kv_cell & cell) { return !cell.is_empty(); });
 
         // sanity check
-        return cache.n >= n_seqs;
+        return llama_kv_cache_slot_info(cache.n >= n_seqs);
     }
     // otherwise, one cell per token.
 
     if (n_tokens > cache.size) {
         LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
-        return false;
+        return llama_kv_cache_slot_info_failed;
     }
 
     uint32_t n_tested = 0;
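Worth noting: in the recurrent branch, `cache.used` is now recomputed from scratch by scanning every cell rather than updated incrementally. A small self-contained sketch of that scan, with a stand-in `llama_kv_cell` (an assumption for illustration; the real cell also tracks sequence ids, so its `is_empty()` differs):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // stand-in cell: just enough state for the emptiness predicate
    struct llama_kv_cell {
        int32_t pos = -1;
        bool is_empty() const { return pos == -1; } // real version also considers seq_id
    };

    int main() {
        std::vector<llama_kv_cell> cells(8); // all empty
        cells[2].pos = 5;
        cells[3].pos = 6;
        // same pattern as the cache.used recomputation above
        const uint32_t used = std::count_if(cells.begin(), cells.end(),
            [](const llama_kv_cell & cell) { return !cell.is_empty(); });
        return used == 2 ? 0 : 1; // exits 0: exactly two cells are occupied
    }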
@@ -3878,15 +3862,10 @@ static bool llama_kv_cache_find_slot(
 
         if (n_tested >= cache.size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return false;
+            return llama_kv_cache_slot_info_failed;
         }
     }
 
-    if (slot_restorer != nullptr) {
-        slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
-        slot_restorer->restore = true;
-    }
-
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
@@ -3900,7 +3879,7 @@ static bool llama_kv_cache_find_slot(
 
     cache.used += n_tokens;
 
-    return true;
+    return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
 }
 
 // find how many cells are currently in use
@@ -4176,22 +4155,47 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
     return cparams.flash_attn ? 256u : 32u;
 }
 
-static void llama_kv_cache_slot_restore(
-        const struct llama_kv_slot_restorer & restorer,
-        struct llama_kv_cache & cache) {
-    if (restorer.restore) {
-        cache.head = restorer.old_state.head;
-        cache.size = restorer.old_state.size;
-        cache.used = restorer.old_state.used;
-        cache.n    = restorer.old_state.n;
-
-        if (cache.recurrent) {
-            cache.cells = restorer.recurrent_cells;
-        } else {
-            llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
+// saves the kv_cache state for future recovery.
+// used to rollback llama_kv_cache_find_slot changes.
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t n    = 0;
+    } old_state;
+
+    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries; // for non-recurrent models only
+
+    bool do_restore = false;
+
+    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
+        old_state.head = cache.head;
+        old_state.n    = cache.n;
+    }
+
+    // records a successful slot search so that it can be rolled back later
+    void save(const struct llama_kv_cache_slot_info & slot) {
+        if (slot) {
+            do_restore = true;
+            if (slot.boundaries.first != slot.boundaries.second) {
+                slot_boundaries.push_back(slot.boundaries);
+            }
         }
     }
-}
+
+    // rolls the cache back to the captured state, erasing any cells claimed since
+    void restore(struct llama_kv_cache & cache) {
+        if (do_restore) {
+            cache.head = old_state.head;
+            cache.n    = old_state.n;
+
+            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
+                llama_kv_cache_seq_rm(cache, -1, -1, -1);
+            } else {
+                for (auto & slot : slot_boundaries) {
+                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
+                }
+            }
+        }
+    }
+};
 
 //
 // model loading and saving
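To see the whole protocol in isolation, here is a hedged, self-contained sketch: stub types (not the real llama.cpp structs) showing that the constructor snapshots `head`/`n`, each successful slot search is recorded, and `restore` rolls the scalars back. The cell erasure that the real `restore` performs through `llama_kv_cache_seq_rm` is reduced to a comment:

    #include <cassert>
    #include <cstdint>
    #include <utility>
    #include <vector>

    struct kv_cache_stub { uint32_t head = 0, n = 0; }; // stand-in for llama_kv_cache

    struct slot_restorer_stub {
        uint32_t old_head, old_n;
        std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
        bool do_restore = false;

        explicit slot_restorer_stub(const kv_cache_stub & cache)
            : old_head(cache.head), old_n(cache.n) {}

        void save(std::pair<uint32_t, uint32_t> boundaries) {
            do_restore = true;
            slot_boundaries.push_back(boundaries); // one entry per ubatch
        }

        void restore(kv_cache_stub & cache) {
            if (!do_restore) return;
            cache.head = old_head;
            cache.n    = old_n;
            // the real restore also erases the cells in each saved range
            // (or wipes everything for recurrent caches)
        }
    };

    int main() {
        kv_cache_stub cache;
        slot_restorer_stub restorer(cache);

        // two ubatches claim cells; head advances as find_slot would leave it
        restorer.save({cache.head, cache.head + 16}); cache.head += 16;
        restorer.save({cache.head, cache.head + 16}); cache.head += 16;

        restorer.restore(cache); // simulate a failed graph compute
        assert(cache.head == 0); // scalars are back to the snapshot
        return 0;
    }

The vector of boundaries, replacing the single pair in the old restorer, is what lets one restorer cover every ubatch of a decode call instead of only the last slot searched.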
@@ -17235,7 +17239,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17320,9 +17324,11 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
+            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            if (!slot) {
                 return 1;
             }
+            kv_slot_restorer.save(slot);
 
             if (!kv_self.recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17371,7 +17377,7 @@ static int llama_decode_internal(
 
         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
+            kv_slot_restorer.restore(kv_self);
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
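Condensing the `llama_decode_internal` changes, the intended control flow looks roughly like this (reusing the stub types from the sketch above; the slot size of 16 and the simulated failure are arbitrary stand-ins for `llama_kv_cache_find_slot` and `llama_graph_compute`):

    // condensed control flow, not the real function
    int decode_sketch(kv_cache_stub & kv_self) {
        slot_restorer_stub restorer(kv_self); // snapshot taken before any slot search

        for (int ubatch = 0; ubatch < 2; ++ubatch) {
            // stands in for: const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
            auto slot = std::make_pair(kv_self.head, kv_self.head + 16);
            restorer.save(slot);
            kv_self.head += 16;

            const bool compute_ok = (ubatch == 0); // pretend the second graph compute fails
            if (!compute_ok) {
                restorer.restore(kv_self); // roll back every slot claimed so far
                return 2;                  // as for GGML_STATUS_ABORTED in the patch
            }
        }
        return 0;
    }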