@@ -821,17 +821,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
             return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_transpose(ctx, v_cur);
-
-        // the row becomes a single element and we repeat the KV indices d_head times
+        // the row becomes a single element
         ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
 
-        v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        // note: the V cache is transposed when not using flash attention
+        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
 
-        // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
-        kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+        // note: we can be more explicit here at the cost of extra cont
+        // however, above we take advantage that a row of single element is always contiguous regardless of the row stride
+        // v_cur = ggml_transpose(ctx, v_cur);
+        // v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
+        // we broadcast the KV indices n_embd_v_gqa times
+        // v       [1, n_kv,     n_embd_v_gqa]
+        // v_cur   [1, n_tokens, n_embd_v_gqa]
+        // kv_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
     }
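The hunk above replaces the transpose + cont + repeat sequence with a single reshape + permute, relying on two facts called out in the comments: a row of one element is contiguous regardless of the row stride, and ggml_set_rows() can broadcast the row indices across the n_embd_v_gqa dimension. The sketch below is not part of the commit; it just reproduces the same shape bookkeeping with made-up sizes, and it assumes a ggml build where ggml_set_rows() takes I64 row indices and supports this broadcast, which is the behavior the change depends on.

// Standalone sketch of the shape gymnastics in cpy_v() for the transposed V cache.
// Assumes ggml_set_rows() accepts I64 indices and broadcasts them over the
// higher dims of the source tensor. Sizes are hypothetical.
#include "ggml.h"

#include <cstdio>

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_v_gqa = 128;  // hypothetical: head dim * number of KV heads
    const int64_t n_kv         = 1024; // hypothetical cache size
    const int64_t n_tokens     = 4;    // hypothetical ubatch size

    // transposed V cache viewed as rows of a single element: [1, n_kv, n_embd_v_gqa]
    ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 1, n_kv, n_embd_v_gqa);

    // incoming V values for the ubatch: [n_embd_v_gqa, n_tokens]
    ggml_tensor * v_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_v_gqa, n_tokens);

    // one destination row index per token: [n_tokens, 1, 1]
    ggml_tensor * kv_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    // same reshape + permute as in the hunk:
    //   [n_embd_v_gqa, n_tokens] -> [n_embd_v_gqa, 1, n_tokens] -> [1, n_tokens, n_embd_v_gqa]
    // no ggml_cont() is needed because every row holds a single element and is
    // therefore contiguous regardless of the row stride
    ggml_tensor * v_cur_rows = ggml_permute(ctx,
            ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, 1, n_tokens), 2, 0, 1, 3);

    // the [n_tokens, 1, 1] indices are broadcast n_embd_v_gqa times over dim 2
    ggml_tensor * out = ggml_set_rows(ctx, v, v_cur_rows, kv_idxs);

    printf("v_cur_rows: [%lld, %lld, %lld]\n",
            (long long) v_cur_rows->ne[0], (long long) v_cur_rows->ne[1], (long long) v_cur_rows->ne[2]);
    printf("out:        [%lld, %lld, %lld]\n",
            (long long) out->ne[0], (long long) out->ne[1], (long long) out->ne[2]);

    ggml_free(ctx);
    return 0;
}

Building the node is enough to exercise the shape constraints; in llama.cpp the op is then added to the graph and executed by the backend, which copies each single-element row of the permuted v_cur into the cache row selected by kv_idxs.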