
Commit efe0bc9

kv-cache : refactor update mechanism
ggml-ci
1 parent 71e74a3 commit efe0bc9

10 files changed: 263 additions & 175 deletions

src/llama-context.cpp

Lines changed: 39 additions & 26 deletions
@@ -429,30 +429,49 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-bool llama_context::kv_self_update() {
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
+}
+
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
+
+        const auto kv_state = kv_self->init_update(this, optimize);
+        if (kv_state->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE) {
+            // no updates have been performed
+            return false;
+        }
+
+        if (!kv_state->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }
 
     // if the KV cache did any computation, we have to reserve a new worst-case graph
     const auto kv_state = kv_self->init_full();
     if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
+        throw std::runtime_error("failed to initialize memory state");
     }
 
     const uint32_t n_seqs   = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
     if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
     }
 
     return true;
@@ -940,13 +959,13 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
+    bool did_optimize = false;
+
     // handle any pending defrags/shifts
-    kv_self_update();
+    kv_self_update(false);
 
     llama_memory_state_ptr kv_state;
 
-    bool did_defrag = false;
-
     while (true) {
         kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
         if (!kv_state) {
@@ -957,14 +976,18 @@ int llama_context::decode(llama_batch & inp_batch) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+                    return -2;
+                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!did_defrag) {
-                        did_defrag = true;
+                    if (!did_optimize) {
+                        did_optimize = true;
 
-                        kv_self->defrag_sched(-1.0f);
-                        if (kv_self_update()) {
-                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
                             continue;
                         }
@@ -1189,11 +1212,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
@@ -2283,7 +2301,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
 
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2538,13 +2556,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-    kv->defrag_sched(-1.0f);
+    ctx->kv_self_defrag_sched();
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
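
For reference, the two-phase flow that the refactored kv_self_update() follows — ask the cache for an update state, bail out on NO_UPDATE, otherwise apply it and let the caller re-reserve the worst-case graph — can be sketched with self-contained stand-in types. Everything named below (memory_state, kv_cache_stub) is a simplified placeholder for illustration only, not the actual llama.cpp classes; only the init_update → get_status → apply sequence mirrors the diff above.

// Minimal sketch of the refactored update flow (stand-in types, not llama.cpp).
#include <cstdio>
#include <memory>

enum memory_status {
    MEMORY_STATUS_SUCCESS,
    MEMORY_STATUS_NO_UPDATE,
};

struct memory_state {
    memory_status status = MEMORY_STATUS_NO_UPDATE;

    memory_status get_status() const { return status; }

    // perform the pending shift/defrag work; false would signal a failure
    bool apply() { return true; }
};

struct kv_cache_stub {
    bool has_pending_shift = false;

    // describe what (if anything) needs to be applied to the cache
    std::unique_ptr<memory_state> init_update(bool optimize) const {
        auto st = std::make_unique<memory_state>();
        st->status = (has_pending_shift || optimize) ? MEMORY_STATUS_SUCCESS
                                                     : MEMORY_STATUS_NO_UPDATE;
        return st;
    }
};

// mirrors llama_context::kv_self_update(bool optimize): returns true only if
// the cache changed, i.e. the caller must reserve a new worst-case graph
bool kv_self_update(const kv_cache_stub & kv, bool optimize) {
    const auto state = kv.init_update(optimize);
    if (state->get_status() == MEMORY_STATUS_NO_UPDATE) {
        return false; // no updates have been performed
    }
    if (!state->apply()) {
        std::fprintf(stderr, "failed to apply memory update\n");
    }
    return true;
}

int main() {
    kv_cache_stub kv;
    std::printf("optimize=false -> %d\n", kv_self_update(kv, false)); // 0: nothing pending
    std::printf("optimize=true  -> %d\n", kv_self_update(kv, true));  // 1: forced optimization
}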

src/llama-context.h

Lines changed: 5 additions & 1 deletion
@@ -52,7 +52,8 @@ struct llama_context {
 
     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();
 
     enum llama_pooling_type pooling_type() const;
 
@@ -231,6 +232,9 @@ struct llama_context {
 
     std::unique_ptr<llama_memory_i> memory;
 
+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;

src/llama-kv-cache-recurrent.cpp

Lines changed: 7 additions & 11 deletions
@@ -386,6 +386,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -419,17 +426,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
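
The recurrent cache never has shift or defrag work, so instead of keeping two no-op overrides (update() and defrag_sched()) it now reports this through the state object: init_update() returns a state whose status is LLAMA_MEMORY_STATUS_NO_UPDATE. A minimal sketch of that pattern follows, using stand-in types rather than the real llama_kv_cache_recurrent_state.

// Sketch of the "nothing to update" state pattern (stand-in types only).
#include <memory>

enum memory_status { MEMORY_STATUS_SUCCESS, MEMORY_STATUS_NO_UPDATE };

struct memory_state_i {
    virtual ~memory_state_i() = default;
    virtual memory_status get_status() const = 0;
    virtual bool apply() = 0;
};

// Reporting NO_UPDATE lets the caller skip apply() and the graph re-reserve,
// replacing per-cache no-op overrides of update()/defrag_sched().
struct no_update_state : memory_state_i {
    memory_status get_status() const override { return MEMORY_STATUS_NO_UPDATE; }
    bool apply() override { return true; } // nothing to apply
};

std::unique_ptr<memory_state_i> init_update(bool /*optimize*/) {
    return std::make_unique<no_update_state>();
}

int main() {
    const auto st = init_update(/*optimize =*/ true);
    return st->get_status() == MEMORY_STATUS_NO_UPDATE ? 0 : 1;
}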

src/llama-kv-cache-recurrent.h

Lines changed: 1 addition & 3 deletions
@@ -52,9 +52,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 37 additions & 24 deletions
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     assert(heads_base.size() == heads_swa.size());
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(
             this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_base->update(lctx);
-    res = res | kv_swa ->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
-    kv_base->defrag_sched(thold);
-    kv_swa ->defrag_sched(thold);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,25 +164,48 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv) : status(status) {
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    // TODO: this is very ugly - how to make it simpler?
+    // the llama_memory_status enum is not very well designed
+    if (state_base->get_status() != LLAMA_MEMORY_STATUS_SUCCESS && state_base->get_status() != LLAMA_MEMORY_STATUS_NO_UPDATE) {
+        status = state_base->get_status();
+        return;
+    }
+
+    if (state_swa->get_status() != LLAMA_MEMORY_STATUS_SUCCESS && state_swa->get_status() != LLAMA_MEMORY_STATUS_NO_UPDATE) {
+        status = state_swa->get_status();
+        return;
+    }
+
+    if (state_base->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE && state_swa->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
+        return;
+    }
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches)
-    : status(status),
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
       sbatch(std::move(sbatch)),
       ubatches(std::move(ubatches)) {
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
 }
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
@@ -239,11 +252,11 @@ const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_base.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_swa.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
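
The update-state constructor above (the "TODO: this is very ugly" block) effectively merges the statuses of the base and SWA sub-states: any failure status wins, the combined state is NO_UPDATE only when both children report NO_UPDATE, and otherwise it is SUCCESS. A small sketch of that rule as a free function, using a stand-in enum rather than the real llama_memory_status:

// Sketch of the status-merging rule used by the iSWA update state (stand-in enum).
#include <cassert>

enum memory_status {
    MEMORY_STATUS_SUCCESS,
    MEMORY_STATUS_NO_UPDATE,
    MEMORY_STATUS_FAILED_PREPARE,
};

memory_status combine(memory_status base, memory_status swa) {
    if (base != MEMORY_STATUS_SUCCESS && base != MEMORY_STATUS_NO_UPDATE) {
        return base; // propagate the base cache's failure
    }
    if (swa != MEMORY_STATUS_SUCCESS && swa != MEMORY_STATUS_NO_UPDATE) {
        return swa; // propagate the SWA cache's failure
    }
    if (base == MEMORY_STATUS_NO_UPDATE && swa == MEMORY_STATUS_NO_UPDATE) {
        return MEMORY_STATUS_NO_UPDATE; // neither child has work to apply
    }
    return MEMORY_STATUS_SUCCESS; // at least one child applied an update
}

int main() {
    assert(combine(MEMORY_STATUS_NO_UPDATE, MEMORY_STATUS_NO_UPDATE)    == MEMORY_STATUS_NO_UPDATE);
    assert(combine(MEMORY_STATUS_NO_UPDATE, MEMORY_STATUS_SUCCESS)      == MEMORY_STATUS_SUCCESS);
    assert(combine(MEMORY_STATUS_FAILED_PREPARE, MEMORY_STATUS_SUCCESS) == MEMORY_STATUS_FAILED_PREPARE);
    return 0;
}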

src/llama-kv-cache-unified-iswa.h

Lines changed: 10 additions & 8 deletions
@@ -54,9 +54,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -86,12 +84,16 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
 
     // used to create a full-cache state
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv);
 
+    // used to create an update state
+    llama_kv_cache_unified_iswa_state(
+            llama_kv_cache_unified_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv,
             llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
@@ -120,7 +122,7 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
     const llama_kv_cache_unified_state * get_swa() const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     //llama_kv_cache_unified_iswa * kv;
 
@@ -131,6 +133,6 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
 
     std::vector<llama_ubatch> ubatches;
 
-    std::unique_ptr<llama_kv_cache_unified_state> state_base;
-    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
+    llama_memory_state_ptr state_base;
+    llama_memory_state_ptr state_swa;
 };
