
Commit 3197a6e

Revert "kv-cache : refactor the update/defrag mechanism (ggml-org#13988)"
This reverts commit 3e63a58.
1 parent e49a9ab · commit 3197a6e

11 files changed: +187 -336 lines

src/llama-context.cpp

Lines changed: 27 additions & 56 deletions
@@ -429,62 +429,30 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }

-void llama_context::kv_self_defrag_sched() {
-    if (!memory) {
-        return;
-    }
-
-    memory_force_optimize = true;
-}
-
-bool llama_context::kv_self_update(bool optimize) {
+bool llama_context::kv_self_update() {
     if (!memory) {
         return false;
     }

     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

-    {
-        // TODO: remove in the future
-        optimize |= memory_force_optimize;
-        memory_force_optimize = false;
-
-        const auto kv_state = kv_self->init_update(this, optimize);
-        switch (kv_state->get_status()) {
-            case LLAMA_MEMORY_STATUS_SUCCESS:
-                {
-                    // noop
-                } break;
-            case LLAMA_MEMORY_STATUS_NO_UPDATE:
-                {
-                    // no updates need to be performed
-                    return false;
-                }
-            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-                {
-                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
-                    return false;
-                }
-        }
-
-        if (!kv_state->apply()) {
-            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
-        }
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
     }

     // if the KV cache did any computation, we have to reserve a new worst-case graph
     const auto kv_state = kv_self->init_full();
     if (!kv_state) {
-        throw std::runtime_error("failed to initialize memory state");
+        throw std::runtime_error("failed to initialize KV cache");
     }

     const uint32_t n_seqs   = cparams.n_seq_max;
     const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

     auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
     if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
     }

     return true;
@@ -972,13 +940,13 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }

-    bool did_optimize = false;
-
     // handle any pending defrags/shifts
-    kv_self_update(false);
+    kv_self_update();

     llama_memory_state_ptr kv_state;

+    bool did_defrag = false;
+
     while (true) {
         kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
         if (!kv_state) {
@@ -989,32 +957,25 @@ int llama_context::decode(llama_batch & inp_batch) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
-            case LLAMA_MEMORY_STATUS_NO_UPDATE:
-                {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
-
-                    return -2;
-                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!did_optimize) {
-                        did_optimize = true;
+                    if (!did_defrag) {
+                        did_defrag = true;

-                        if (kv_self_update(true)) {
-                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);

                             continue;
                         }
                     }

-                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);

                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
-                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
-
                     return -2;
                 }
         }
@@ -1231,6 +1192,11 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();

+    // decide if we need to defrag the kv cache
+    if (cparams.defrag_thold > 0.0f) {
+        kv_self->defrag_sched(cparams.defrag_thold);
+    }
+
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
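Note: the restored block above schedules a defrag after every decode whenever the configured threshold is positive. That threshold is the public defrag_thold field of llama_context_params. A minimal sketch of enabling it at context creation; the model pointer and the helper name are assumptions for illustration, not part of this commit:

    #include "llama.h"

    // sketch: enable automatic defrag scheduling via the public context params
    // (model is assumed to be a llama_model * loaded elsewhere)
    static llama_context * make_ctx_with_defrag(llama_model * model) {
        llama_context_params cparams = llama_context_default_params();
        cparams.defrag_thold = 0.1f; // any positive value enables the scheduling shown above
        return llama_init_from_model(model, cparams);
    }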
@@ -2320,7 +2286,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {

 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update(false);
+    ctx->kv_self_update();
 }

 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2575,8 +2541,13 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {

 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
     // force defrag
-    ctx->kv_self_defrag_sched();
+    kv->defrag_sched(-1.0f);
 }

 bool llama_kv_self_can_shift(const llama_context * ctx) {
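For context, the two deprecated C API entry points touched above can still be combined to force a defrag from user code. A minimal sketch, assuming a valid llama_context * ctx created elsewhere; the wrapper function name is illustrative only:

    #include "llama.h"

    // sketch: force a KV cache defrag through the deprecated public API
    static void force_kv_defrag(llama_context * ctx) {
        llama_kv_self_defrag(ctx); // schedules a forced defrag (thold = -1.0f) on the cache
        llama_kv_self_update(ctx); // applies pending shifts/defrags via llama_context::kv_self_update()
    }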

src/llama-context.h

Lines changed: 1 addition & 5 deletions
@@ -52,8 +52,7 @@ struct llama_context {

     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update(bool optimize);
-    void kv_self_defrag_sched();
+    bool kv_self_update();

     enum llama_pooling_type pooling_type() const;

@@ -232,9 +231,6 @@ struct llama_context {

     std::unique_ptr<llama_memory_i> memory;

-    // TODO: temporary, until the llama_kv_self_defrag() API is removed
-    bool memory_force_optimize = false;
-
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;

src/llama-kv-cache-recurrent.cpp

Lines changed: 11 additions & 8 deletions
@@ -1,7 +1,6 @@
 #include "llama-kv-cache-recurrent.h"

 #include "llama-impl.h"
-#include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"

@@ -387,13 +386,6 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }

-llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
-    GGML_UNUSED(lctx);
-    GGML_UNUSED(optimize);
-
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
-}
-
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -427,6 +419,17 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }

+bool llama_kv_cache_recurrent::update(llama_context & lctx) {
+    GGML_UNUSED(lctx);
+    // noop
+    return false;
+}
+
+void llama_kv_cache_recurrent::defrag_sched(float thold) {
+    GGML_UNUSED(thold);
+    // noop
+}
+
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;

src/llama-kv-cache-recurrent.h

Lines changed: 3 additions & 1 deletion
@@ -52,7 +52,9 @@ class llama_kv_cache_recurrent : public llama_kv_cache {

     llama_memory_state_ptr init_full() override;

-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    bool update(llama_context & lctx) override;
+
+    void defrag_sched(float thold) override;

     bool prepare(const std::vector<llama_ubatch> & ubatches);

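Both cache implementations now override the same pair of virtuals on the shared base class. The base declaration lives in src/llama-kv-cache.h, which is not part of this excerpt; the following is only a sketch of its assumed shape, inferred from the override declarations in this header and in the iswa header further down:

    struct llama_context;

    // sketch (assumed shape of the restored base interface; the real declaration
    // is in src/llama-kv-cache.h and is not shown in this diff)
    struct llama_kv_cache {
        // apply any pending K-shift / defrag work; return true if computation was performed
        virtual bool update(llama_context & lctx) = 0;

        // schedule a defrag for the next update(); thold < 0.0f forces it
        virtual void defrag_sched(float thold) = 0;

        virtual ~llama_kv_cache() = default;
    };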

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 28 additions & 31 deletions
@@ -123,16 +123,26 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch

     assert(heads_base.size() == heads_swa.size());

-    return std::make_unique<llama_kv_cache_unified_iswa_state>(
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
             this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 }

 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }

-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
+bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
+    bool res = false;
+
+    res = res | kv_base->update(lctx);
+    res = res | kv_swa ->update(lctx);
+
+    return res;
+}
+
+void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
+    kv_base->defrag_sched(thold);
+    kv_swa ->defrag_sched(thold);
 }

 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -164,38 +174,26 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_full();
-    state_swa  = kv->get_swa ()->init_full();
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
-}
-
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_kv_cache_unified_iswa * kv,
-        llama_context * lctx,
-        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_update(lctx, optimize);
-    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        llama_memory_status status,
+        llama_kv_cache_unified_iswa * kv) : status(status) {
+    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
+    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
 }

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_memory_status status,
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches)
-    : status(LLAMA_MEMORY_STATUS_SUCCESS),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
-    // note: here we copy the ubatches. not sure if this is ideal
-    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
-    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
-}
+    : status(status),
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
+}

 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;

@@ -235,18 +233,17 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {

 const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
     return ubatches[i_next];
 }

 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
+    return state_base.get();
 }

 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
+    return state_swa.get();
 }
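One detail worth noting in the restored llama_kv_cache_unified_iswa::update() above: the two results are accumulated with bitwise | rather than logical ||, so the SWA cache gets updated even when the base cache already reported work. A standalone illustration of the difference; this is not llama.cpp code and the function names are made up for the example:

    #include <cstdio>

    // stand-ins for kv_base->update(lctx) and kv_swa->update(lctx)
    static bool update_base() { std::puts("base updated"); return true; }
    static bool update_swa()  { std::puts("swa updated");  return true; }

    int main() {
        bool res = false;
        res = res | update_base(); // bitwise OR: both calls always run
        res = res | update_swa();

        // with `res = update_base() || update_swa();` the second call would be
        // skipped as soon as the first one returned true (short-circuit)
        std::printf("any update performed: %d\n", res);
        return 0;
    }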

src/llama-kv-cache-unified-iswa.h

Lines changed: 8 additions & 10 deletions
@@ -54,7 +54,9 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {

     llama_memory_state_ptr init_full() override;

-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+    bool update(llama_context & lctx) override;
+
+    void defrag_sched(float thold) override;

     bool get_can_shift() const override;

@@ -84,16 +86,12 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {

     // used to create a full-cache state
     llama_kv_cache_unified_iswa_state(
+            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv);

-    // used to create an update state
-    llama_kv_cache_unified_iswa_state(
-            llama_kv_cache_unified_iswa * kv,
-            llama_context * lctx,
-            bool optimize);
-
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
+            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv,
             llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
@@ -122,7 +120,7 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
     const llama_kv_cache_unified_state * get_swa() const;

 private:
-    llama_memory_status status;
+    const llama_memory_status status;

     //llama_kv_cache_unified_iswa * kv;

@@ -133,6 +131,6 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {

     std::vector<llama_ubatch> ubatches;

-    llama_memory_state_ptr state_base;
-    llama_memory_state_ptr state_swa;
+    std::unique_ptr<llama_kv_cache_unified_state> state_base;
+    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
 };
