@@ -291,6 +291,11 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        kv_state->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+        kv_state->get_swa()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
@@ -1287,8 +1292,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, nullptr, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, nullptr, il));
+        const auto & kv_idxs = inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1428,6 +1435,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = kv_state->get_base()->get_n_kv();
 
+        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        ggml_set_input(inp->self_kv_idxs);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
        // cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
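
Note on the change: `self_kv_idxs` is a new I32 graph input with one entry per ubatch token, carrying the destination cell index in the KV cache; `cpy_k`/`cpy_v` now receive these indices instead of `nullptr`, and `set_input_kv_idxs` fills the tensor on the host at `set_input` time. Below is a minimal sketch of how such an input could be populated, assuming tokens are placed into consecutive cells starting at the current cache head; the free-standing `set_input_kv_idxs_sketch` function and its `head` parameter are illustrative assumptions, not the implementation from this diff.

// Sketch only: fill the I32 kv_idxs input with one destination cell per token.
// Assumes consecutive cell assignment starting at `head`; the real cache may
// assign cells differently.
static void set_input_kv_idxs_sketch(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head) {
    GGML_ASSERT(dst->type == GGML_TYPE_I32);

    const uint32_t n_tokens = ubatch->n_tokens;

    int32_t * data = (int32_t *) dst->data;

    for (uint32_t i = 0; i < n_tokens; ++i) {
        data[i] = (int32_t) (head + i);
    }
}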