Commit db0cd69

kv-cache : support non-FA case
ggml-ci
1 parent 28ee6d2

1 file changed: +14 −5 lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 14 additions & 5 deletions
@@ -777,23 +777,32 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (!v_trans) {
         if (kv_idxs) {
-            return ggml_set_rows(ctx, v, ggml_reshape_2d(ctx, v_cur, v->ne[0], n_tokens), kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
+        // note: the V cache is transposed when not using flash attention
         if (kv_idxs) {
-            GGML_ABORT("TODO: implement kv_idxs for transposed V cache -- for now use flash attention");
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+            v_cur = ggml_cont(ctx, v_cur);
+            v_cur = ggml_reshape_3d(ctx, v_cur, 1, n_tokens, hparams.n_embd_v_gqa(il));
+
+            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
         }
 
-        // note: the V cache is transposed when not using flash attention
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
-
-        v_cur = ggml_transpose(ctx, v_cur);
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
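
To make the new else branch easier to follow: without flash attention the V cache is stored transposed (slot-fastest), so one token's values for a given embedding channel are not contiguous and ggml_set_rows cannot copy them as a single row. The diff therefore reshapes the cache to rows of length 1 ([1, v->ne[1], v->ne[0]]), reshapes the transposed v_cur the same way, and repeats the token indices once per channel, turning the copy into an element-wise scatter. The standalone sketch below (plain C++, not ggml; the ggml_set_rows semantics dst[:, idx[i1, i2], i2] = src[:, i1, i2] is my reading of this diff, and the toy dimensions are hypothetical) shows the resulting write pattern:

// Sketch of the transposed-cache scatter performed by the new path above.
// kv_size/n_embd/n_tokens mirror the names in the real code; everything
// else here is illustrative.
#include <cstdio>
#include <vector>

int main() {
    const int kv_size = 8, n_embd = 4, n_tokens = 3;

    // transposed V cache: element (slot, dim) lives at slot + kv_size*dim,
    // matching the reshape of v to [1, v->ne[1], v->ne[0]]
    std::vector<float> cache(kv_size * n_embd, 0.0f);

    // incoming V values after ggml_transpose + ggml_cont + reshape to
    // [1, n_tokens, n_embd]: element (tok, dim) lives at tok + n_tokens*dim
    std::vector<float> v_cur(n_tokens * n_embd);
    for (int dim = 0; dim < n_embd; ++dim) {
        for (int tok = 0; tok < n_tokens; ++tok) {
            v_cur[tok + n_tokens*dim] = 100.0f*tok + dim;
        }
    }

    // destination cache slots, one per token (non-contiguous on purpose);
    // ggml_repeat_4d turns this into an [n_tokens, n_embd] index tensor in
    // which every channel repeats the same slot list
    const int idxs[3] = {5, 1, 6};

    // the set_rows scatter: rows have length 1, so each (tok, dim) pair
    // moves exactly one element into v_view[0, idxs[tok], dim]
    for (int dim = 0; dim < n_embd; ++dim) {
        for (int tok = 0; tok < n_tokens; ++tok) {
            cache[idxs[tok] + kv_size*dim] = v_cur[tok + n_tokens*dim];
        }
    }

    // one line per embedding channel, one column per cache slot:
    // tokens 0..2 land in slots 5, 1 and 6 of every channel
    for (int dim = 0; dim < n_embd; ++dim) {
        for (int slot = 0; slot < kv_size; ++slot) {
            printf("%6.0f", cache[slot + kv_size*dim]);
        }
        printf("\n");
    }
    return 0;
}

For comparison, the !v_trans (flash attention) branch keeps the cache row-major per slot, so the simplified ggml_set_rows(ctx, v, v_cur, kv_idxs) call at the top of the hunk writes each token as one row copy; the per-element scatter is the price of the transposed layout, which is what the in-diff TODO is about.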
