Skip to content

Commit 61710fc

Browse files
committed
kv-cache : hide defrag logic in the implementation
ggml-ci
1 parent bb1c81c commit 61710fc

File tree

4 files changed

+21
-20
lines changed

4 files changed

+21
-20
lines changed

src/llama-context.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,19 +1426,8 @@ int llama_context::decode(llama_batch & inp_batch) {
14261426
//synchronize();
14271427

14281428
// decide if we need to defrag the kv cache
1429-
if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) {
1430-
auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
1431-
1432-
// - do not defrag small contexts (i.e. < 2048 tokens)
1433-
// - count the padding towards the number of used tokens
1434-
const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->padding)/float(kv->n)) : 0.0f;
1435-
1436-
// queue defragmentation for next llama_kv_cache_update
1437-
if (fragmentation > cparams.defrag_thold) {
1438-
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
1439-
1440-
kv_self->defrag();
1441-
}
1429+
if (cparams.defrag_thold > 0.0f) {
1430+
kv_self->defrag(cparams.defrag_thold);
14421431
}
14431432

14441433
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -2588,7 +2577,8 @@ void llama_kv_self_defrag(llama_context * ctx) {
25882577
return;
25892578
}
25902579

2591-
return kv->defrag();
2580+
// force defrag
2581+
return kv->defrag(-1.0f);
25922582
}
25932583

25942584
// deprecated

src/llama-kv-cache.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,17 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
357357
return result;
358358
}
359359

360-
void llama_kv_cache_unified::defrag() {
361-
do_defrag = true;
360+
void llama_kv_cache_unified::defrag(float thold) {
361+
// - do not defrag small contexts (i.e. < 2048 tokens)
362+
// - count the padding towards the number of used tokens
363+
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - float(used + padding)/float(n)) : 0.0f;
364+
365+
// queue defragmentation for next llama_kv_cache_update
366+
if (fragmentation > thold) {
367+
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
368+
369+
do_defrag = true;
370+
}
362371
}
363372

364373
void llama_kv_cache_unified::restore() {
@@ -1358,7 +1367,8 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
13581367
return result;
13591368
}
13601369

1361-
void llama_kv_cache_recurrent::defrag() {
1370+
void llama_kv_cache_recurrent::defrag(float thold) {
1371+
GGML_UNUSED(thold);
13621372
LLAMA_LOG_ERROR("%s: not supported\n", __func__);
13631373
}
13641374

src/llama-kv-cache.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ struct llama_kv_cache : public llama_memory_i {
3131
virtual void restore() = 0; // call if batch processing fails - restores the cache state
3232
virtual void commit() = 0; // call after successful batch processing - clears any pending state
3333

34+
virtual void defrag(float thold) = 0;
35+
3436
virtual int32_t get_n_tokens() const = 0;
3537
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
3638

@@ -124,7 +126,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
124126
llama_pos get_pos_max() const override;
125127

126128
void clear() override;
127-
void defrag() override;
129+
void defrag(float thold) override;
128130

129131
void restore() override;
130132
void commit() override;
@@ -252,7 +254,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
252254
llama_pos get_pos_max() const override;
253255

254256
void clear() override;
255-
void defrag() override;
257+
void defrag(float thold) override;
256258

257259
void restore() override;
258260
void commit() override;

src/llama-memory.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ class llama_memory_i {
2323
virtual ~llama_memory_i() = default;
2424

2525
virtual void clear() = 0;
26-
virtual void defrag() = 0;
2726

2827
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
2928
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;

0 commit comments

Comments (0)