Commit 3e63a58

kv-cache : refactor the update/defrag mechanism (#13988)
* kv-cache : refactor update mechanism ggml-ci
* memory : improve status handling
* defrag : reset head + add comments ggml-ci
* cont : minor fixes ggml-ci
1 parent 2589ad3 commit 3e63a58

11 files changed: +336 -187 lines
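In short, the patch removes the imperative `update()` / `defrag_sched()` pair from the `llama_kv_cache` interface and replaces it with a single `init_update()` call that returns a `llama_memory_state_ptr`; the caller inspects the state's status and then applies it. Below is a minimal sketch of the resulting call pattern, using only the interfaces visible in this diff; the free function and its name are illustrative and not part of the commit.

// Illustrative sketch only: drives one update round through the new interface.
// The status handling mirrors llama_context::kv_self_update() in the diff below.
static bool run_kv_update(llama_kv_cache * kv, llama_context * lctx, bool optimize) {
    const auto state = kv->init_update(lctx, optimize); // llama_memory_state_ptr

    switch (state->get_status()) {
        case LLAMA_MEMORY_STATUS_SUCCESS:
            break;            // work was prepared, apply it below
        case LLAMA_MEMORY_STATUS_NO_UPDATE:
            return false;     // nothing to do (e.g. the recurrent cache always reports this)
        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
            return false;     // preparation or compute failed
    }

    return state->apply();    // perform the scheduled shift/defrag work
}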

src/llama-context.cpp

Lines changed: 56 additions & 27 deletions
@@ -429,30 +429,62 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }

-bool llama_context::kv_self_update() {
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
+}
+
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }

     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
+
+        const auto kv_state = kv_self->init_update(this, optimize);
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
+
+        if (!kv_state->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }

     // if the KV cache did any computation, we have to reserve a new worst-case graph
     const auto kv_state = kv_self->init_full();
     if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
+        throw std::runtime_error("failed to initialize memory state");
     }

     const uint32_t n_seqs   = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
     if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
     }

     return true;
@@ -940,13 +972,13 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }

+    bool did_optimize = false;
+
     // handle any pending defrags/shifts
-    kv_self_update();
+    kv_self_update(false);

     llama_memory_state_ptr kv_state;

-    bool did_defrag = false;
-
     while (true) {
         kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
         if (!kv_state) {
@@ -957,25 +989,32 @@ int llama_context::decode(llama_batch & inp_batch) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+
+                    return -2;
+                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!did_defrag) {
-                        did_defrag = true;
+                    if (!did_optimize) {
+                        did_optimize = true;

-                        kv_self->defrag_sched(-1.0f);
-                        if (kv_self_update()) {
-                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);

                             continue;
                         }
                     }

-                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);

                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+
                     return -2;
                 }
         }
@@ -1189,11 +1228,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();

-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
@@ -2283,7 +2317,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {

 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }

 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2538,13 +2572,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {

 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-    kv->defrag_sched(-1.0f);
+    ctx->kv_self_defrag_sched();
 }

 bool llama_kv_self_can_shift(const llama_context * ctx) {
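With this refactor, the deprecated `llama_kv_self_defrag()` no longer touches the cache directly: it only raises `memory_force_optimize` via `kv_self_defrag_sched()`, and that flag is ORed into the `optimize` argument of the next `kv_self_update(false)` call at the top of `decode()`. A minimal usage sketch under that assumption; the driver code below is hypothetical, and `ctx` and `batch` are assumed to be a valid context and batch.

// hypothetical driver code, not part of the commit
llama_kv_self_defrag(ctx);   // deprecated: now only schedules the optimization
// nothing has happened yet; the work runs lazily inside the next decode,
// where kv_self_update(false) picks up memory_force_optimize
if (llama_decode(ctx, batch) != 0) {
    // batch was not processed (no memory slot found or compute error)
}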

src/llama-context.h

Lines changed: 5 additions & 1 deletion
@@ -52,7 +52,8 @@ struct llama_context {

     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();

     enum llama_pooling_type pooling_type() const;

@@ -231,6 +232,9 @@ struct llama_context {

     std::unique_ptr<llama_memory_i> memory;

+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;

src/llama-kv-cache-recurrent.cpp

Lines changed: 8 additions & 11 deletions
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-recurrent.h"

 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"

@@ -386,6 +387,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }

+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -419,17 +427,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }

-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;

src/llama-kv-cache-recurrent.h

Lines changed: 1 addition & 3 deletions
@@ -52,9 +52,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {

     llama_memory_state_ptr init_full() override;

-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

     bool prepare(const std::vector<llama_ubatch> & ubatches);

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 31 additions & 28 deletions
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch

     assert(heads_base.size() == heads_swa.size());

-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(
             this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 }

 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }

-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_base->update(lctx);
-    res = res | kv_swa ->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
-    kv_base->defrag_sched(thold);
-    kv_swa ->defrag_sched(thold);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }

 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv) : status(status) {
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches)
-    : status(status),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
-    // note: here we copy the ubatches. not sure if this is ideal
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
-}
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}

 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;

@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {

 const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
     return ubatches[i_next];
 }

 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    return state_base.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }

 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    return state_swa.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
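The ISWA state now stores its children as generic `llama_memory_state_ptr` and derives its own status from theirs with `llama_memory_status_combine()`. That helper is declared elsewhere and is not part of this diff; the sketch below is only one plausible combination rule, written as an assumption for illustration, and the real implementation may differ.

// Assumed semantics, NOT the actual llama_memory_status_combine():
// any failure wins; otherwise SUCCESS if at least one child has work; otherwise NO_UPDATE.
static llama_memory_status combine_status_sketch(llama_memory_status s0, llama_memory_status s1) {
    if (s0 == LLAMA_MEMORY_STATUS_FAILED_PREPARE || s0 == LLAMA_MEMORY_STATUS_FAILED_COMPUTE) {
        return s0;
    }
    if (s1 == LLAMA_MEMORY_STATUS_FAILED_PREPARE || s1 == LLAMA_MEMORY_STATUS_FAILED_COMPUTE) {
        return s1;
    }
    if (s0 == LLAMA_MEMORY_STATUS_SUCCESS || s1 == LLAMA_MEMORY_STATUS_SUCCESS) {
        return LLAMA_MEMORY_STATUS_SUCCESS;
    }
    return LLAMA_MEMORY_STATUS_NO_UPDATE;
}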

src/llama-kv-cache-unified-iswa.h

Lines changed: 10 additions & 8 deletions
@@ -54,9 +54,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {

     llama_memory_state_ptr init_full() override;

-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

     bool get_can_shift() const override;

@@ -86,12 +84,16 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {

     // used to create a full-cache state
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv);

+    // used to create an update state
+    llama_kv_cache_unified_iswa_state(
+            llama_kv_cache_unified_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv,
             llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
@@ -120,7 +122,7 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
     const llama_kv_cache_unified_state * get_swa() const;

 private:
-    const llama_memory_status status;
+    llama_memory_status status;

     //llama_kv_cache_unified_iswa * kv;

@@ -131,6 +133,6 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {

     std::vector<llama_ubatch> ubatches;

-    std::unique_ptr<llama_kv_cache_unified_state> state_base;
-    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
+    llama_memory_state_ptr state_base;
+    llama_memory_state_ptr state_swa;
 };
