Skip to content

Commit 8f1c5e3

Browse files
committed
kv-cache : use ggml_set_rows
ggml-ci
1 parent a8cd49b commit 8f1c5e3

File tree

4 files changed

+58
-16
lines changed

4 files changed

+58
-16
lines changed

src/llama-graph.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
281281
}
282282

283283
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
284+
if (self_kv_idxs) {
285+
kv_state->set_input_kv_idxs(self_kv_idxs, ubatch);
286+
}
287+
284288
if (self_kq_mask) {
285289
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
286290
}
@@ -1192,6 +1196,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
11921196

11931197
const auto n_kv = kv_state->get_n_kv();
11941198

1199+
inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
1200+
ggml_set_input(inp->self_kv_idxs);
1201+
11951202
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
11961203
//cb(inp->self_kq_mask, "KQ_mask", -1);
11971204
ggml_set_input(inp->self_kq_mask);
@@ -1224,8 +1231,10 @@ ggml_tensor * llm_graph_context::build_attn(
12241231

12251232
// store to KV cache
12261233
{
1227-
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1228-
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1234+
const auto & kv_idxs = inp->get_kv_idxs();
1235+
1236+
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, kv_idxs, il));
1237+
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, kv_idxs, il));
12291238
}
12301239

12311240
const auto & kq_mask = inp->get_kq_mask();
@@ -1278,8 +1287,8 @@ ggml_tensor * llm_graph_context::build_attn(
12781287

12791288
// store to KV cache
12801289
{
1281-
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1282-
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1290+
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, nullptr, il));
1291+
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, nullptr, il));
12831292
}
12841293

12851294
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1383,8 +1392,8 @@ ggml_tensor * llm_graph_context::build_attn(
13831392

13841393
// store to KV cache
13851394
{
1386-
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
1387-
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
1395+
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, nullptr, il));
1396+
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, nullptr, il));
13881397
}
13891398

13901399
const auto & kq_mask = inp->get_kq_mask();

src/llama-graph.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,12 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
247247

248248
void set_input(const llama_ubatch * ubatch) override;
249249

250+
ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
250251
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
251252

253+
// TODO: should this be I64?
254+
ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
255+
252256
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
253257
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
254258

src/llama-kv-cache-unified.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -746,21 +746,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
746746
0);
747747
}
748748

749-
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
749+
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
750750
const int32_t ikv = map_layer_ids.at(il);
751751

752752
auto * k = layers[ikv].k;
753753

754754
const int64_t n_tokens = k_cur->ne[2];
755755

756+
if (kv_idxs) {
757+
return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
758+
}
759+
756760
ggml_tensor * k_view = ggml_view_1d(ctx, k,
757761
n_tokens*hparams.n_embd_k_gqa(il),
758762
ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
759763

760764
return ggml_cpy(ctx, k_cur, k_view);
761765
}
762766

763-
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
767+
ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
764768
const int32_t ikv = map_layer_ids.at(il);
765769

766770
auto * v = layers[ikv].v;
@@ -772,10 +776,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
772776
ggml_tensor * v_view = nullptr;
773777

774778
if (!v_trans) {
779+
if (kv_idxs) {
780+
return ggml_set_rows(ctx, v, ggml_reshape_2d(ctx, v_cur, v->ne[0], n_tokens), kv_idxs);
781+
}
782+
775783
v_view = ggml_view_1d(ctx, v,
776784
n_tokens*hparams.n_embd_v_gqa(il),
777785
ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
778786
} else {
787+
if (kv_idxs) {
788+
GGML_ABORT("TODO: implement kv_idxs for transposed V cache -- for now use flash attention");
789+
}
790+
779791
// note: the V cache is transposed when not using flash attention
780792
v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
781793
(v->ne[1])*ggml_element_size(v),
@@ -787,6 +799,17 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
787799
return ggml_cpy(ctx, v_cur, v_view);
788800
}
789801

802+
void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
803+
const uint32_t n_tokens = ubatch->n_tokens;
804+
805+
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
806+
int32_t * data = (int32_t *) dst->data;
807+
808+
for (uint32_t i = 0; i < n_tokens; ++i) {
809+
data[i] = head_cur + i;
810+
}
811+
}
812+
790813
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
791814
const uint32_t n_tokens = ubatch->n_tokens;
792815

@@ -1789,18 +1812,22 @@ ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il
17891812
return kv->get_v(ctx, il, n_kv);
17901813
}
17911814

1792-
ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1793-
return kv->cpy_k(ctx, k_cur, il, head);
1815+
ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
1816+
return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
17941817
}
17951818

1796-
ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1797-
return kv->cpy_v(ctx, v_cur, il, head);
1819+
ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
1820+
return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
17981821
}
17991822

18001823
void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
18011824
kv->set_input_k_shift(dst);
18021825
}
18031826

1827+
void llama_kv_cache_unified_state::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1828+
kv->set_input_kv_idxs(dst, ubatch, head);
1829+
}
1830+
18041831
void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
18051832
kv->set_input_kq_mask(dst, ubatch, causal_attn);
18061833
}

src/llama-kv-cache-unified.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ class llama_kv_cache_unified : public llama_memory_i {
102102
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
103103

104104
// store k_cur and v_cur in the cache based on the provided head location
105-
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
106-
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
105+
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;
106+
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;
107107

108108
//
109109
// preparation API
@@ -126,6 +126,7 @@ class llama_kv_cache_unified : public llama_memory_i {
126126
// set_input API
127127
//
128128

129+
void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const;
129130
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
130131
void set_input_k_shift (ggml_tensor * dst) const;
131132
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -257,11 +258,12 @@ class llama_kv_cache_unified_state : public llama_memory_state_i {
257258
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
258259

259260
// store k_cur and v_cur in the cache based on the provided head location
260-
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
261-
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
261+
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const;
262+
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const;
262263

263264
void set_input_k_shift(ggml_tensor * dst) const;
264265

266+
void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch) const;
265267
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
266268
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
267269

0 commit comments

Comments
 (0)