Skip to content

Commit 5b7cc53

Browse files
committed
llama : add struct llama_kv_cache (wip) [no ci]
1 parent 924518e commit 5b7cc53

File tree

8 files changed

+428
-415
lines changed

8 files changed

+428
-415
lines changed

common/common.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,9 @@ struct common_init_result common_init_from_params(common_params & params) {
893893
return iparams;
894894
}
895895

896-
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
896+
llama_kv_cache * kv = llama_get_kv_cache(lctx);
897+
898+
if (params.ctx_shift && !llama_kv_cache_can_shift(kv)) {
897899
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
898900
params.ctx_shift = false;
899901
}
@@ -998,7 +1000,7 @@ struct common_init_result common_init_from_params(common_params & params) {
9981000
if (llama_model_has_decoder(model)) {
9991001
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
10001002
}
1001-
llama_kv_cache_clear(lctx);
1003+
llama_kv_cache_clear(kv);
10021004
llama_synchronize(lctx);
10031005
llama_perf_context_reset(lctx);
10041006
}

common/speculative.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,10 @@ llama_tokens common_speculative_gen_draft(
171171
llama_tokens result;
172172
result.reserve(params.n_draft);
173173

174+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
175+
174176
if (reuse_n == 0) {
175-
llama_kv_cache_clear(ctx);
177+
llama_kv_cache_clear(kv);
176178

177179
prompt.clear();
178180
} else {
@@ -191,14 +193,14 @@ llama_tokens common_speculative_gen_draft(
191193
}
192194

193195
if (reuse_i > 0) {
194-
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
195-
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
196+
llama_kv_cache_seq_rm (kv, 0, 0, reuse_i);
197+
llama_kv_cache_seq_add(kv, 0, reuse_i, -1, -reuse_i);
196198

197199
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
198200
}
199201

200202
if (reuse_n < (int) prompt.size()) {
201-
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
203+
llama_kv_cache_seq_rm (kv, 0, reuse_n, -1);
202204

203205
prompt.erase(prompt.begin() + reuse_n, prompt.end());
204206
}

examples/embedding/embedding.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,11 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
3434

3535
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
3636
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
37-
const struct llama_model * model = llama_get_model(ctx);
37+
const llama_model * model = llama_get_model(ctx);
38+
llama_kv_cache * kv = llama_get_kv_cache(ctx);
3839

3940
// clear previous kv_cache values (irrelevant for embeddings)
40-
llama_kv_cache_clear(ctx);
41+
llama_kv_cache_clear(kv);
4142

4243
// run model
4344
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

include/llama.h

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ extern "C" {
6060
struct llama_model;
6161
struct llama_context;
6262
struct llama_sampler;
63+
struct llama_kv_cache;
6364

6465
typedef int32_t llama_pos;
6566
typedef int32_t llama_token;
@@ -460,8 +461,9 @@ extern "C" {
460461

461462
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
462463

463-
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
464-
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
464+
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
465+
LLAMA_API struct llama_kv_cache * llama_get_kv_cache( struct llama_context * ctx);
466+
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
465467

466468
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
467469
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@@ -576,7 +578,7 @@ extern "C" {
576578
// KV cache
577579
//
578580

579-
// TODO: remove llama_kv_cache_view_* API
581+
// TODO: start using struct llama_kv_cache
580582

581583
// Information associated with an individual cell in the KV cache view.
582584
struct llama_kv_cache_view_cell {
@@ -631,41 +633,47 @@ extern "C" {
631633

632634
// Returns the number of tokens in the KV cache (slow, use only for debug)
633635
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
634-
LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
636+
LLAMA_API int32_t llama_kv_cache_n_tokens(const struct llama_kv_cache * kv);
637+
638+
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx),
639+
"use llama_kv_cache_n_tokens instead");
635640

636641
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
637-
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
642+
LLAMA_API int32_t llama_kv_cache_used_cells(const struct llama_kv_cache * kv);
643+
644+
DEPRECATED(LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx),
645+
"use llama_kv_cache_used_cells instead");
638646

639647
// Clear the KV cache - both cell info is erased and KV data is zeroed
640648
LLAMA_API void llama_kv_cache_clear(
641-
struct llama_context * ctx);
649+
struct llama_kv_cache * kv);
642650

643651
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
644652
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
645653
// seq_id < 0 : match any sequence
646654
// p0 < 0 : [0, p1]
647655
// p1 < 0 : [p0, inf)
648656
LLAMA_API bool llama_kv_cache_seq_rm(
649-
struct llama_context * ctx,
650-
llama_seq_id seq_id,
651-
llama_pos p0,
652-
llama_pos p1);
657+
struct llama_kv_cache * kv,
658+
llama_seq_id seq_id,
659+
llama_pos p0,
660+
llama_pos p1);
653661

654662
// Copy all tokens that belong to the specified sequence to another sequence
655663
// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
656664
// p0 < 0 : [0, p1]
657665
// p1 < 0 : [p0, inf)
658666
LLAMA_API void llama_kv_cache_seq_cp(
659-
struct llama_context * ctx,
660-
llama_seq_id seq_id_src,
661-
llama_seq_id seq_id_dst,
662-
llama_pos p0,
663-
llama_pos p1);
667+
struct llama_kv_cache * kv,
668+
llama_seq_id seq_id_src,
669+
llama_seq_id seq_id_dst,
670+
llama_pos p0,
671+
llama_pos p1);
664672

665673
// Removes all tokens that do not belong to the specified sequence
666674
LLAMA_API void llama_kv_cache_seq_keep(
667-
struct llama_context * ctx,
668-
llama_seq_id seq_id);
675+
struct llama_kv_cache * kv,
676+
llama_seq_id seq_id);
669677

670678
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
671679
// If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -674,11 +682,11 @@ extern "C" {
674682
// p0 < 0 : [0, p1]
675683
// p1 < 0 : [p0, inf)
676684
LLAMA_API void llama_kv_cache_seq_add(
677-
struct llama_context * ctx,
678-
llama_seq_id seq_id,
679-
llama_pos p0,
680-
llama_pos p1,
681-
llama_pos delta);
685+
struct llama_kv_cache * kv,
686+
llama_seq_id seq_id,
687+
llama_pos p0,
688+
llama_pos p1,
689+
llama_pos delta);
682690

683691
// Integer division of the positions by factor of `d > 1`
684692
// If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -687,31 +695,28 @@ extern "C" {
687695
// p0 < 0 : [0, p1]
688696
// p1 < 0 : [p0, inf)
689697
LLAMA_API void llama_kv_cache_seq_div(
690-
struct llama_context * ctx,
691-
llama_seq_id seq_id,
692-
llama_pos p0,
693-
llama_pos p1,
694-
int d);
698+
struct llama_kv_cache * kv,
699+
llama_seq_id seq_id,
700+
llama_pos p0,
701+
llama_pos p1,
702+
int d);
695703

696704
// Returns the largest position present in the KV cache for the specified sequence
697705
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
698-
struct llama_context * ctx,
699-
llama_seq_id seq_id);
700-
701-
// TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
702-
// how to avoid this?
706+
struct llama_kv_cache * kv,
707+
llama_seq_id seq_id);
703708

704709
// Defragment the KV cache
705710
// This will be applied:
706711
// - lazily on next llama_decode()
707712
// - explicitly with llama_kv_cache_update()
708-
LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
709-
710-
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
711-
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
713+
LLAMA_API void llama_kv_cache_defrag(struct llama_kv_cache * kv);
712714

713715
// Check if the context supports KV cache shifting
714-
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);
716+
LLAMA_API bool llama_kv_cache_can_shift(const struct llama_kv_cache * kv);
717+
718+
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
719+
LLAMA_API void llama_update_kv_cache(struct llama_context * ctx, struct llama_kv_cache * kv);
715720

716721
//
717722
// State / sessions

src/llama-context.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -602,11 +602,15 @@ uint32_t llama_n_seq_max(const struct llama_context * ctx) {
602602
return ctx->kv_self.size;
603603
}
604604

605-
const struct llama_model * llama_get_model(const struct llama_context * ctx) {
605+
const llama_model * llama_get_model(const llama_context * ctx) {
606606
return &ctx->model;
607607
}
608608

609-
enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
609+
llama_kv_cache * llama_get_kv_cache(llama_context * ctx) {
610+
return &ctx->kv_self;
611+
}
612+
613+
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
610614
return ctx->cparams.pooling_type;
611615
}
612616

@@ -1142,7 +1146,7 @@ struct llama_data_read {
11421146
if (dest_seq_id != -1) {
11431147
// single sequence
11441148

1145-
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
1149+
kv_self.seq_rm(dest_seq_id, -1, -1);
11461150

11471151
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
11481152
batch.n_tokens = cell_count;
@@ -1185,7 +1189,7 @@ struct llama_data_read {
11851189
return false;
11861190
}
11871191

1188-
llama_kv_cache_clear(kv_self);
1192+
kv_self.clear();
11891193

11901194
for (uint32_t i = 0; i < cell_count; ++i) {
11911195
llama_kv_cell & cell = kv_self.cells[i];
@@ -1362,9 +1366,9 @@ struct llama_data_read {
13621366

13631367
if (!res) {
13641368
if (seq_id == -1) {
1365-
llama_kv_cache_clear(ctx);
1369+
ctx->kv_self.clear();
13661370
} else {
1367-
llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
1371+
ctx->kv_self.seq_rm(seq_id, -1, -1);
13681372
}
13691373
throw std::runtime_error("failed to restore kv cache");
13701374
}

0 commit comments

Comments
 (0)