Skip to content

Commit a540bcd

Browse files
committed
kv-cache : add ubatch_next()
ggml-ci
1 parent b6bdfd3 commit a540bcd

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

src/llama-context.cpp

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,22 +1239,7 @@ int llama_context::decode(llama_batch & inp_batch) {
12391239
int64_t n_outputs_prev = 0;
12401240

12411241
while (sbatch.n_tokens > 0) {
1242-
llama_ubatch ubatch = llama_ubatch();
1243-
1244-
const auto & n_ubatch = cparams.n_ubatch;
1245-
1246-
if (is_recurrent) {
1247-
if (embd_pooled) {
1248-
// Pooled embeddings cannot be split across ubatches (yet)
1249-
ubatch = sbatch.split_seq(cparams.n_ubatch);
1250-
} else {
1251-
// recurrent model architectures are easier to implement
1252-
// with equal-length sequences
1253-
ubatch = sbatch.split_equal(cparams.n_ubatch);
1254-
}
1255-
} else {
1256-
ubatch = sbatch.split_simple(n_ubatch);
1257-
}
1242+
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
12581243

12591244
// count the outputs in this u_batch
12601245
{
@@ -1424,7 +1409,7 @@ int llama_context::decode(llama_batch & inp_batch) {
14241409

14251410
// - do not defrag small contexts (i.e. < 2048 tokens)
14261411
// - count the padding towards the number of used tokens
1427-
const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f;
1412+
const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->padding)/float(kv->n)) : 0.0f;
14281413

14291414
// queue defragmentation for next llama_kv_cache_update
14301415
if (fragmentation > cparams.defrag_thold) {

src/llama-kv-cache.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,14 @@ bool llama_kv_cache_unified::find_slot(
476476
return true;
477477
}
478478

479+
// Produce the next micro-batch from `sbatch` for the unified (attention) KV
// cache. The attention cache imposes no constraint on how tokens are grouped,
// so a plain contiguous split of up to `n_ubatch` tokens is always used.
// `embd_pooled` only affects recurrent caches and is deliberately unused here.
llama_ubatch llama_kv_cache_unified::ubatch_next(
        llama_sbatch & sbatch,
            uint32_t   n_ubatch,
                bool   embd_pooled) const {
    GGML_UNUSED(embd_pooled); // only meaningful for the recurrent cache

    return sbatch.split_simple(n_ubatch);
}
486+
479487
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
480488
// the FA kernels require padding to avoid extra runtime boundary checks
481489
return cparams.flash_attn ? 256u : 32u;
@@ -1539,6 +1547,15 @@ bool llama_kv_cache_recurrent::find_slot(
15391547
return n >= n_seqs;
15401548
}
15411549

1550+
// Produce the next micro-batch from `sbatch` for the recurrent KV cache.
// The splitting strategy depends on whether pooled embeddings are requested:
// - pooled embeddings cannot be split across ubatches (yet), so split by
//   sequence
// - otherwise split into equal-length sequences, which are easier to handle
//   for recurrent model architectures
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
    return embd_pooled
        ? sbatch.split_seq  (n_ubatch)  // Pooled embeddings cannot be split across ubatches (yet)
        : sbatch.split_equal(n_ubatch); // equal-length sequences for recurrent architectures
}
1558+
15421559
uint32_t llama_kv_cache_recurrent::cell_max() const {
15431560
for (uint32_t i = size; i > 0; --i) {
15441561
const llama_kv_cell & cell = cells[i - 1];

src/llama-kv-cache.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
struct llama_cparams;
1414
struct llama_hparams;
1515
struct llama_ubatch;
16+
struct llama_sbatch;
1617

1718
struct llama_kv_cache : public llama_memory_i {
1819
// can be used to query data from the model if needed
@@ -44,6 +45,9 @@ struct llama_kv_cache : public llama_memory_i {
4445

4546
virtual bool find_slot(const llama_ubatch & batch) = 0;
4647

48+
// different KV caches require different batch splitting strategies
49+
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
50+
4751
// simulate full cache, used for allocating worst-case compute buffers
4852
virtual void set_full() = 0;
4953

@@ -139,6 +143,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
139143
// to the first cell of the slot.
140144
bool find_slot(const llama_ubatch & batch) override;
141145

146+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
147+
142148
static uint32_t get_padding(const llama_cparams & cparams);
143149

144150
// find how many cells are currently in use
@@ -263,6 +269,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
263269
// to the first cell of the slot.
264270
bool find_slot(const llama_ubatch & batch) override;
265271

272+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
273+
266274
// find how many cells are currently in use
267275
uint32_t cell_max() const;
268276

0 commit comments

Comments
 (0)