@@ -777,23 +777,32 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     if (!v_trans) {
         if (kv_idxs) {
-            return ggml_set_rows(ctx, v, ggml_reshape_2d(ctx, v_cur, v->ne[0], n_tokens), kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
+        // note: the V cache is transposed when not using flash attention
         if (kv_idxs) {
-            GGML_ABORT("TODO: implement kv_idxs for transposed V cache -- for now use flash attention");
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+            v_cur = ggml_cont(ctx, v_cur);
+            v_cur = ggml_reshape_3d(ctx, v_cur, 1, n_tokens, hparams.n_embd_v_gqa(il));
+
+            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
         }
 
-        // note: the V cache is transposed when not using flash attention
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
-
-        v_cur = ggml_transpose(ctx, v_cur);
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
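
As a rough, non-authoritative sketch of what the new transposed-cache branch is doing: each cache element becomes its own one-element "row", and the per-token KV index is repeated across every embedding dimension (the repetition mentioned in the comment) so that ggml_set_rows can scatter element by element into the transposed layout. The plain-C model below only illustrates that indexing; the sizes, buffers, and names other than kv_idxs are made-up placeholders, not the actual llama.cpp tensors.

// Rough illustration only (not llama.cpp code): a plain-C model of the scatter
// that the new transposed-cache branch expresses with ggml_reshape_3d +
// ggml_repeat_4d + ggml_set_rows. All sizes are made-up placeholders.
#include <stdint.h>
#include <stdio.h>

#define KV_SIZE  8   // number of cache cells
#define N_EMBD   4   // stand-in for n_embd_v_gqa(il)
#define N_TOKENS 3   // tokens in the current batch

int main(void) {
    float   v_cache[N_EMBD*KV_SIZE] = {0};   // transposed cache layout: [dim][cell]
    float   v_cur  [N_EMBD*N_TOKENS];        // incoming values, transposed: [dim][token]
    int64_t kv_idxs[N_TOKENS] = {5, 2, 6};   // destination cell of each token

    for (int e = 0; e < N_EMBD; ++e) {
        for (int t = 0; t < N_TOKENS; ++t) {
            v_cur[e*N_TOKENS + t] = 100.0f*e + (t + 1);
        }
    }

    // Viewing the transposed cache as rows of length 1, the element for
    // (cell, dim) lives at flat row index dim*KV_SIZE + cell. Repeating the
    // per-token cell indices once per embedding dim (the ggml_repeat_4d step)
    // turns the copy into one single-element row write per (token, dim) pair:
    for (int e = 0; e < N_EMBD; ++e) {
        for (int t = 0; t < N_TOKENS; ++t) {
            v_cache[e*KV_SIZE + kv_idxs[t]] = v_cur[e*N_TOKENS + t];
        }
    }

    // print one embedding dim: tokens 1..3 land in cells 5, 2 and 6
    for (int c = 0; c < KV_SIZE; ++c) {
        printf("%5.1f", v_cache[0*KV_SIZE + c]);   // -> 0 0 2 0 0 1 3 0
    }
    printf("\n");

    return 0;
}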