Skip to content

Commit 28ee6d2

Browse files
committed
graph : add kv_idxs to unified_iswa input [no ci]
1 parent 5f87f28 commit 28ee6d2

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

src/llama-graph.cpp

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -291,6 +291,11 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
291291
}
292292

293293
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
294+
if (self_kv_idxs) {
295+
kv_state->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
296+
kv_state->get_swa ()->set_input_kv_idxs(self_kv_idxs, ubatch);
297+
}
298+
294299
if (self_kq_mask) {
295300
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
296301
}
@@ -1287,8 +1292,10 @@ ggml_tensor * llm_graph_context::build_attn(
12871292

12881293
// store to KV cache
12891294
{
1290-
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, nullptr, il));
1291-
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, nullptr, il));
1295+
const auto & kv_idxs = inp->get_kv_idxs();
1296+
1297+
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, kv_idxs, il));
1298+
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, kv_idxs, il));
12921299
}
12931300

12941301
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1428,6 +1435,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
14281435
{
14291436
const auto n_kv = kv_state->get_base()->get_n_kv();
14301437

1438+
inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
1439+
ggml_set_input(inp->self_kv_idxs);
1440+
14311441
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
14321442
//cb(inp->self_kq_mask, "KQ_mask", -1);
14331443
ggml_set_input(inp->self_kq_mask);

src/llama-graph.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -276,9 +276,12 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
276276

277277
void set_input(const llama_ubatch * ubatch) override;
278278

279+
ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
279280
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
280281
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
281282

283+
ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
284+
282285
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
283286
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
284287
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]

0 commit comments

Comments (0)