Skip to content

Commit c4273b8

Browse files
committed
kv-cache : utilize ggml_set_rows broadcast
ggml-ci
1 parent 53327f4 commit c4273b8

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -821,17 +821,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
821821
return ggml_set_rows(ctx, v, v_cur, kv_idxs);
822822
}
823823

824-
// note: the V cache is transposed when not using flash attention
825-
v_cur = ggml_transpose(ctx, v_cur);
826-
827-
// the row becomes a single element and we repeat the KV indices d_head times
824+
// the row becomes a single element
828825
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
829826

830-
v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
827+
// note: the V cache is transposed when not using flash attention
828+
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
831829

832-
// TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
833-
kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
830+
// note: we can be more explicit here at the cost of an extra cont
831+
// however, above we take advantage of the fact that a single-element row is always contiguous regardless of the row stride
832+
//v_cur = ggml_transpose(ctx, v_cur);
833+
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
834834

835+
// we broadcast the KV indices n_embd_v_gqa times
836+
// v [1, n_kv, n_embd_v_gqa]
837+
// v_cur [1, n_tokens, n_embd_v_gqa]
838+
// kv_idxs [n_tokens, 1, 1]
835839
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
836840
}
837841

0 commit comments

Comments
 (0)