Skip to content

Commit bbf27cc

Browse files
committed
llama: correct reverting of the entire batch.
also updates `llama_kv_cache_find_slot` so that it correctly counts the number of `used` cells for recurrent models
1 parent 0c05c60 commit bbf27cc

File tree

1 file changed

+64
-58
lines changed

1 file changed

+64
-58
lines changed

src/llama.cpp

Lines changed: 64 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2815,22 +2815,6 @@ struct llama_kv_cache {
28152815
}
28162816
};
28172817

2818-
// saves the kv_cache state for future recovery
2819-
// used to preserve the kv_cache state before searching for a slot
2820-
struct llama_kv_slot_restorer {
2821-
struct llama_kv_cache_state {
2822-
uint32_t head = 0;
2823-
uint32_t size = 0;
2824-
uint32_t used = 0;
2825-
uint32_t n = 0;
2826-
} old_state;
2827-
2828-
std::vector<llama_kv_cell> recurrent_cells; // for recurrent models only
2829-
std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
2830-
2831-
bool restore = false;
2832-
};
2833-
28342818
struct llama_control_vector {
28352819
std::vector<struct ggml_tensor *> tensors; // per layer
28362820
std::vector<struct ggml_context *> ctxs;
@@ -3666,21 +3650,24 @@ static bool llama_kv_cache_init(
36663650
// updates the cache head
36673651
// Note: On success, it's important that cache.head points
36683652
// to the first cell of the slot.
3669-
static bool llama_kv_cache_find_slot(
3653+
// Holds the result of llama_kv_cache_find_slot: whether a usable slot was
// found and, when it was, the [begin, end) cell range it occupies.
struct llama_kv_cache_slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // slot cell range (non-recurrent caches)
    bool found = false;                       // true when a usable slot was located

    // Record only success/failure, without a cell range.
    explicit llama_kv_cache_slot_info(bool success) : found(success) {}

    // Record a successful result covering cells [begin, end).
    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries(begin, end), found(true) {}

    // Allow the result to be tested directly in boolean context.
    operator bool() const { return found; }
};

// Shared sentinel returned whenever no slot could be found.
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
3663+
3664+
static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
36703665
struct llama_kv_cache & cache,
3671-
const struct llama_ubatch & batch,
3672-
struct llama_kv_slot_restorer * slot_restorer = nullptr) {
3666+
const struct llama_ubatch & batch) {
36733667
const uint32_t n_tokens = batch.n_tokens;
36743668
const uint32_t n_seqs = batch.n_seqs;
36753669
const uint32_t n_seq_tokens = batch.n_seq_tokens;
36763670

3677-
if (slot_restorer != nullptr) {
3678-
slot_restorer->old_state.head = cache.head;
3679-
slot_restorer->old_state.size = cache.size;
3680-
slot_restorer->old_state.used = cache.used;
3681-
slot_restorer->old_state.n = cache.n;
3682-
}
3683-
36843671
if (cache.recurrent) {
36853672
// For recurrent state architectures (like Mamba or RWKV),
36863673
// each cache cell can store the state for a whole sequence.
@@ -3689,11 +3676,6 @@ static bool llama_kv_cache_find_slot(
36893676
// can only process batches with an equal number of new tokens in each sequence
36903677
GGML_ASSERT(batch.equal_seqs);
36913678

3692-
if (slot_restorer != nullptr) {
3693-
slot_restorer->recurrent_cells = cache.cells;
3694-
slot_restorer->restore = true;
3695-
}
3696-
36973679
int32_t min = cache.size - 1;
36983680
int32_t max = 0;
36993681

@@ -3707,7 +3689,7 @@ static bool llama_kv_cache_find_slot(
37073689
// too big seq_id
37083690
// TODO: would it be possible to resize the cache instead?
37093691
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
3710-
return false;
3692+
return llama_kv_cache_slot_info_failed;
37113693
}
37123694
if (j > 0) {
37133695
llama_kv_cell & seq = cache.cells[seq_id];
@@ -3842,15 +3824,17 @@ static bool llama_kv_cache_find_slot(
38423824
// allow getting the range of used cells, from head to head + n
38433825
cache.head = min;
38443826
cache.n = max - min + 1;
3827+
cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
3828+
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
38453829

38463830
// sanity check
3847-
return cache.n >= n_seqs;
3831+
return llama_kv_cache_slot_info(cache.n >= n_seqs);
38483832
}
38493833
// otherwise, one cell per token.
38503834

38513835
if (n_tokens > cache.size) {
38523836
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
3853-
return false;
3837+
return llama_kv_cache_slot_info_failed;
38543838
}
38553839

38563840
uint32_t n_tested = 0;
@@ -3878,15 +3862,10 @@ static bool llama_kv_cache_find_slot(
38783862

38793863
if (n_tested >= cache.size) {
38803864
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
3881-
return false;
3865+
return llama_kv_cache_slot_info_failed;
38823866
}
38833867
}
38843868

3885-
if (slot_restorer != nullptr) {
3886-
slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
3887-
slot_restorer->restore = true;
3888-
}
3889-
38903869
for (uint32_t s = 0; s < n_seqs; s++) {
38913870
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
38923871
uint32_t k = s*n_seq_tokens + i;
@@ -3900,7 +3879,7 @@ static bool llama_kv_cache_find_slot(
39003879

39013880
cache.used += n_tokens;
39023881

3903-
return true;
3882+
return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
39043883
}
39053884

39063885
// find how many cells are currently in use
@@ -4176,22 +4155,47 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
41764155
return cparams.flash_attn ? 256u : 32u;
41774156
}
41784157

4179-
static void llama_kv_cache_slot_restore(
4180-
const struct llama_kv_slot_restorer & restorer,
4181-
struct llama_kv_cache & cache) {
4182-
if (restorer.restore) {
4183-
cache.head = restorer.old_state.head;
4184-
cache.size = restorer.old_state.size;
4185-
cache.used = restorer.old_state.used;
4186-
cache.n = restorer.old_state.n;
4187-
4188-
if (cache.recurrent) {
4189-
cache.cells = restorer.recurrent_cells;
4190-
} else {
4191-
llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
4158+
// saves the kv_cache state for future recovery.
4159+
// used to rollback llama_kv_cache_find_slot changes.
4160+
struct llama_kv_slot_restorer {
4161+
struct llama_kv_cache_state {
4162+
uint32_t head = 0;
4163+
uint32_t n = 0;
4164+
} old_state;
4165+
4166+
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries; // for non-recurrent models only
4167+
4168+
bool do_restore = false;
4169+
4170+
explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
4171+
old_state.head = cache.head;
4172+
old_state.n = cache.n;
4173+
}
4174+
4175+
void save(const struct llama_kv_cache_slot_info& slot) {
4176+
if (slot) {
4177+
do_restore = true;
4178+
if (slot.boundaries.first != slot.boundaries.second) {
4179+
slot_boundaries.push_back(slot.boundaries);
4180+
}
41924181
}
41934182
}
4194-
}
4183+
4184+
void restore(struct llama_kv_cache & cache) {
4185+
if (do_restore) {
4186+
cache.head = old_state.head;
4187+
cache.n = old_state.n;
4188+
4189+
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
4190+
llama_kv_cache_seq_rm(cache, -1, -1, -1);
4191+
} else {
4192+
for (auto & slot : slot_boundaries) {
4193+
llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
4194+
}
4195+
}
4196+
}
4197+
}
4198+
};
41954199

41964200
//
41974201
// model loading and saving
@@ -17235,7 +17239,7 @@ static int llama_decode_internal(
1723517239
lctx.n_queued_tokens += n_tokens_all;
1723617240

1723717241
auto & kv_self = lctx.kv_self;
17238-
llama_kv_slot_restorer kv_slot_restorer;
17242+
llama_kv_slot_restorer kv_slot_restorer(kv_self);
1723917243

1724017244
const int64_t n_embd = hparams.n_embd;
1724117245
const int64_t n_vocab = hparams.n_vocab;
@@ -17320,9 +17324,11 @@ static int llama_decode_internal(
1732017324
kv_self.head = 0;
1732117325
}
1732217326

17323-
if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
17327+
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
17328+
if (!slot) {
1732417329
return 1;
1732517330
}
17331+
kv_slot_restorer.save(slot);
1732617332

1732717333
if (!kv_self.recurrent) {
1732817334
// a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17371,7 +17377,7 @@ static int llama_decode_internal(
1737117377

1737217378
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
1737317379
if (compute_status != GGML_STATUS_SUCCESS) {
17374-
llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
17380+
kv_slot_restorer.restore(kv_self);
1737517381
switch (compute_status) {
1737617382
case GGML_STATUS_ABORTED:
1737717383
return 2;

0 commit comments

Comments (0)