Commit b6bdfd3

kv-cache : hide padding details in the implementation
ggml-ci
1 parent 1b53231 commit b6bdfd3

File tree

    src/llama-context.cpp
    src/llama-kv-cache.cpp
    src/llama-kv-cache.h
    src/llama-model.cpp

4 files changed, 26 insertions(+), 23 deletions(-)

src/llama-context.cpp (3 additions, 17 deletions)

@@ -1274,24 +1274,10 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
 
     // find KV slot
-    {
-        if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-            return 1;
-        }
+    if (!kv_self->find_slot(ubatch)) {
+        LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
 
-        if (!is_recurrent) {
-            auto * kv = static_cast<llama_kv_cache_unified *>(kv_self);
-
-            // a heuristic, to avoid attending the full cache if it is not yet utilized
-            // after enough generations, the benefit from this heuristic disappears
-            // if we start defragmenting the cache, the benefit from this will be more important
-            const uint32_t pad = kv->get_padding(cparams);
-            kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad)));
-
-            //printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head);
-        }
+        return 1;
     }
 
     ggml_backend_sched_reset(sched.get());
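
The deleted block is exactly the detail the commit title refers to: llama_context::decode no longer downcasts to llama_kv_cache_unified and recomputes the attention window n itself; find_slot now does this internally (see src/llama-kv-cache.cpp below). For reference, GGML_PAD in ggml.h rounds its first argument up to the next multiple of the second. A minimal standalone sketch of the window heuristic, with hypothetical values standing in for the cache members:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // GGML_PAD as defined in ggml.h: round x up to the next multiple of n
    // (n must be a power of two for the bit trick to work)
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        // hypothetical values standing in for the cache members
        const uint32_t size     = 4096; // kv->size: total number of KV cells
        const uint32_t pad      = 256;  // kv->get_padding(cparams)
        const uint32_t cell_max = 1000; // kv->cell_max(): roughly, one past the last used cell

        // the heuristic from the deleted block: attend to a padded prefix of
        // the cache instead of all of it while utilization is still low
        const uint32_t n = std::min(size, std::max(pad, GGML_PAD(cell_max, pad)));

        printf("n = %u\n", n); // 1000 rounds up to 1024, clamped to at most 4096
        return 0;
    }

With cell_max = 1000 and pad = 256 this yields n = 1024, so the graph attends to a quarter of the 4096-cell cache instead of all of it.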

src/llama-kv-cache.cpp (13 additions, 3 deletions)

@@ -21,14 +21,17 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_type type_k,
         ggml_type type_v,
         bool v_trans,
-        uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans) {
+        uint32_t kv_size,
+        uint32_t padding) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans), padding(padding) {
     const int32_t n_layer = hparams.n_layer;
 
     has_shift = false;
     can_shift = true;
 
-    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
-            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
+    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
+
+    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
 
     head = 0;
     size = kv_size;

@@ -463,6 +466,13 @@ bool llama_kv_cache_unified::find_slot(
 
     pending.ranges.push_back({head, head + n_tokens});
 
+    // a heuristic, to avoid attending the full cache if it is not yet utilized
+    // after enough generations, the benefit from this heuristic disappears
+    // if we start defragmenting the cache, the benefit from this will be more important
+    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
+
+    //printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
+
     return true;
 }
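
A note on the new GGML_ASSERT: cell_max() can be at most size, so when kv_size is a multiple of padding, GGML_PAD(cell_max(), padding) never exceeds size and the std::min clamp never yields a misaligned window. A small self-contained check of that invariant, under assumed values:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        const uint32_t padding = 32;
        const uint32_t size    = 512; // a multiple of padding, as the constructor now asserts

        // for every possible utilization level, the computed window n stays
        // padding-aligned and within bounds
        for (uint32_t cell_max = 0; cell_max <= size; ++cell_max) {
            const uint32_t n = std::min(size, std::max(padding, GGML_PAD(cell_max, padding)));
            assert(n % padding == 0 && n <= size);
        }
        return 0;
    }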

src/llama-kv-cache.h (5 additions, 1 deletion)

@@ -101,7 +101,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
             ggml_type type_k,
             ggml_type type_v,
             bool v_trans,
-            uint32_t kv_size);
+            uint32_t kv_size,
+            uint32_t padding);
 
     ~llama_kv_cache_unified() = default;
 
@@ -196,6 +197,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
     // computed before each graph build
     uint32_t n = 0;
 
+    // required padding
+    uint32_t padding = 1;
+
     std::vector<llama_kv_cell> cells;
 
     std::vector<ggml_tensor *> k_l; // per layer
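
The in-class default padding = 1 keeps an unconfigured cache well defined: GGML_PAD(x, 1) reduces to (x + 0) & ~0, which is just x, so the heuristic degenerates to n = min(size, max(1, cell_max())). A compile-time sketch of that identity:

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    // with n = 1 the mask ~(1 - 1) is all ones, so GGML_PAD(x, 1) == x
    static_assert(GGML_PAD(1000u, 1u) == 1000u, "padding = 1 leaves sizes unchanged");
    static_assert(GGML_PAD(1000u, 32u) == 1024u, "padding = 32 rounds up to the next multiple");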

src/llama-model.cpp (5 additions, 2 deletions)

@@ -12808,7 +12808,9 @@ llama_memory_i * llama_model::create_memory(llama_cparams & cparams, const llama
                 } break;
             default:
                 {
-                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams));
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
 
                     LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
@@ -12832,7 +12834,8 @@ llama_memory_i * llama_model::create_memory(llama_cparams & cparams, const llama
                             params.type_k,
                             params.type_v,
                             !cparams.flash_attn,
-                            cparams.n_ctx);
+                            cparams.n_ctx,
+                            padding);
                 }
             }
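
create_memory now computes the padding once, pads n_ctx with it, and forwards both values to the cache constructor, so the divisibility the constructor asserts holds by construction. A sketch of that flow; get_padding here is a hypothetical stand-in (the real llama_kv_cache_unified::get_padding derives the value from cparams, choosing coarser alignment when flash attention is enabled):

    #include <cstdint>
    #include <cstdio>

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    // hypothetical stand-in for llama_kv_cache_unified::get_padding(cparams)
    static uint32_t get_padding(bool flash_attn) {
        return flash_attn ? 256u : 32u; // assumed values: FA kernels want coarser alignment
    }

    int main() {
        uint32_t n_ctx = 4100; // user-requested context size, not aligned

        const uint32_t padding = get_padding(/*flash_attn=*/true);
        n_ctx = GGML_PAD(n_ctx, padding); // 4100 -> 4352

        // the padded n_ctx and the padding are now passed together to the
        // llama_kv_cache_unified constructor, which asserts divisibility
        printf("n_ctx = %u, padding = %u, n_ctx %% padding = %u\n",
                n_ctx, padding, n_ctx % padding);
        return 0;
    }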
