
Commit 93864cd

llama : experimental DeepSeek2 MLA implementation that caches latent kv representations
Parent: 6369f86

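The idea: with multi-head latent attention (MLA), the cache needs to hold, per token and per layer, only the compressed latent c^KV (kv_lora_rank entries) plus a single shared RoPE-ed key K^R (n_embd_head_qk_rope entries), instead of full per-head keys and values. A rough sketch of what that buys, assuming DeepSeek-V2-style dimensions (illustrative values, not read from this diff):

#include <cstdio>

// Per-token, per-layer KV cache entries: standard caching vs. the latent
// caching introduced in this commit. Dimensions are assumptions matching
// DeepSeek-V2's published config, purely for illustration.
int main() {
    const unsigned n_head              = 128;
    const unsigned n_embd_head_qk_nope = 128;
    const unsigned n_embd_head_qk_rope = 64;
    const unsigned n_embd_head_v       = 128;
    const unsigned kv_lora_rank        = 512;

    // full cache: every head stores its (nope + rope) key and its value
    const unsigned full = n_head * (n_embd_head_qk_nope + n_embd_head_qk_rope)
                        + n_head * n_embd_head_v;

    // latent cache: one c^KV plus one shared K^R per token
    const unsigned mla = kv_lora_rank + n_embd_head_qk_rope;

    printf("full: %u entries, mla: %u entries (%.1fx smaller)\n",
           full, mla, (double)full / mla);  // full: 40960, mla: 576 (71.1x)
    return 0;
}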
File tree: 3 files changed, +106 −16 lines

  src/llama-kv-cache.cpp  (+15 −1)
  src/llama-kv-cache.h    (+7 −0)
  src/llama.cpp           (+84 −15)

src/llama-kv-cache.cpp

Lines changed: 15 additions & 1 deletion
@@ -53,7 +53,7 @@ bool llama_kv_cache_init(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             struct ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(4u*n_layer*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -71,6 +71,10 @@ bool llama_kv_cache_init(
     cache.k_l.reserve(n_layer);
     cache.v_l.reserve(n_layer);

+    // DeepSeek MLA
+    cache.kr_l.reserve(n_layer);
+    cache.kv_l.reserve(n_layer);
+
     for (int i = 0; i < n_layer; i++) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@@ -97,6 +101,16 @@ bool llama_kv_cache_init(
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
         cache.v_l.push_back(v);
+
+        // DeepSeek MLA
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
+        ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
+        ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
+        ggml_format_name(kr, "cache_kr_l%d", i);
+        ggml_format_name(kv, "cache_kv_l%d", i);
+        cache.kr_l.push_back(kr);
+        cache.kv_l.push_back(kv);
     }

     // allocate tensors and initialize the buffers to avoid NaNs in the padding
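Each layer thus gets two flat 1-D buffers, one slot per cache cell. A quick footprint check (a sketch; kv_size, kv_lora_rank and n_embd_head_qk_rope are illustrative, and the element size uses the F32 defaults added in llama-kv-cache.h below):

#include <cstdio>

// Expected byte sizes of the per-layer MLA tensors allocated above.
int main() {
    const size_t kv_size             = 4096; // number of cache cells (illustrative)
    const size_t kv_lora_rank        = 512;  // hparams.n_lora_kv (illustrative)
    const size_t n_embd_head_qk_rope = 64;   // hparams.n_rot (illustrative)
    const size_t elem                = 4;    // bytes per element for GGML_TYPE_F32

    printf("cache_kr_l: %.2f MiB, cache_kv_l: %.2f MiB per layer\n",
           n_embd_head_qk_rope * kv_size * elem / (1024.0 * 1024.0),
           kv_lora_rank        * kv_size * elem / (1024.0 * 1024.0));
    // -> cache_kr_l: 1.00 MiB, cache_kv_l: 8.00 MiB per layer
    return 0;
}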

src/llama-kv-cache.h

Lines changed: 7 additions & 0 deletions
@@ -49,11 +49,18 @@ struct llama_kv_cache {
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;

+    ggml_type type_kr = GGML_TYPE_F32;
+    ggml_type type_kv = GGML_TYPE_F32;
+
     std::vector<llama_kv_cell> cells;

     std::vector<struct ggml_tensor *> k_l; // per layer
     std::vector<struct ggml_tensor *> v_l;

+    // DeepSeek MLA
+    std::vector<struct ggml_tensor *> kr_l; // per layer
+    std::vector<struct ggml_tensor *> kv_l;
+
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
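Unlike type_k/type_v, which default to F16 and can be overridden through llama_context_params, the new latent caches are pinned to F32 here. A minimal sketch of what a configurable variant could look like (hypothetical; this commit exposes no such option):

    llama_kv_cache cache;
    // hypothetical override, mirroring how type_k/type_v are wired up;
    // the commit hard-codes GGML_TYPE_F32 for both fields
    cache.type_kr = GGML_TYPE_F16;
    cache.type_kv = GGML_TYPE_F16;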

src/llama.cpp

Lines changed: 84 additions & 15 deletions
@@ -8860,32 +8860,37 @@ struct llm_build_context {
                 LLM_NORM_RMS, cb, il);
         cb(kv_compressed, "kv_compressed", il);

+        struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+        cb(kv_cache_view, "kv_cache_view", il);
+
+        // note: storing c^KV in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
+
+        struct ggml_tensor * kv_cache =
+            ggml_view_2d(ctx0, kv_self.kv_l[il],
+                kv_lora_rank, n_kv,
+                ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+                0);
+        cb(kv_cache, "kv_cache", il);
+
         // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-        struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+        struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache);
         cb(kv, "kv", il);

         // split into {n_head * n_embd_head_qk_nope, n_tokens}
-        struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+        struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv,
                 ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
                 ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                 0);
         cb(k_nope, "k_nope", il);

         // and {n_head * n_embd_head_v, n_tokens}
-        struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+        struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv,
                 ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                 ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
                 ggml_row_size(kv->type, (n_embd_head_qk_nope)));
         cb(v_states, "v_states", il);

-        v_states = ggml_cont(ctx0, v_states);
-        cb(v_states, "v_states", il);
-
-        v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-            ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-            0);
-        cb(v_states, "v_states", il);
-
         q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
         q_pe = ggml_rope_ext(
             ctx0, q_pe, inp_pos, nullptr,
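A shape walkthrough of the decompression path above may help (a sketch; ggml's ne[] order puts the fastest dimension first, and concrete numbers would be DeepSeek-V2-style values):

// kv_cache : {kv_lora_rank, n_kv}                                  cached latents c^KV
// wkv_b    : {kv_lora_rank, n_head*(n_embd_head_qk_nope + n_embd_head_v)}
// kv       : {n_head*(n_embd_head_qk_nope + n_embd_head_v), n_kv}  decompressed K/V
// k_nope   : {n_embd_head_qk_nope, n_head, n_kv}                   strided view into kv
// v_states : {n_embd_head_v, n_head, n_kv}                         view offset past each k_nope
//
// Because only the latents are cached, wkv_b re-decompresses all n_kv cells
// on every graph build; that is why the views now span n_kv instead of
// n_tokens. The commit trades this recompute for the smaller cache.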
@@ -8903,15 +8908,61 @@ struct llm_build_context {
         );
         cb(k_pe, "k_pe", il);

+        struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+        cb(kr_cache_view, "kr_cache_view", il);
+
+        // note: storing RoPE-ed version of K^R in the KV cache
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
+
         struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
         cb(q_states, "q_states", il);

-        struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+        struct ggml_tensor * kr_cache =
+            ggml_view_2d(ctx0, kv_self.kr_l[il],
+                n_embd_head_qk_rope, n_kv,
+                ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+                0);
+        cb(kr_cache, "kr_cache", il);
+
+        // TODO is there a better way?
+        struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head);
+        struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape);
+        kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3);
+        struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0);
         cb(k_states, "k_states", il);

-        cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                model.layers[il].wo, NULL,
-                k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+        q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3);
+        cb(q_states, "q_states", il);
+
+        k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3);
+        cb(k_states, "k_states", il);
+
+        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states);
+        cb(kq, "kq", il);
+
+        kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+        cb(kq, "kq_soft_max_ext", il);
+
+        v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3);
+        cb(v_states, "v_states", il);
+
+        v_states = ggml_cont(ctx0, v_states);
+
+        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq);
+        cb(kqv, "kqv", il);
+
+        GGML_ASSERT(kv_self.size == n_ctx);
+
+        struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        cb(kqv_merged, "kqv_merged", il);
+
+        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cb(cur, "kqv_merged_cont", il);
+
+        ggml_build_forward_expand(gf, cur);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+        cb(cur, "kqv_out", il);
     }

     if (il == n_layer - 1) {
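And a sketch of the shapes through the hand-rolled attention above, which replaces the llm_build_kv call (again in ggml's ne[] order, fastest dimension first):

// kr_cache : {n_embd_head_qk_rope, n_kv}              one shared K^R per cell
// kr_rep   : {n_embd_head_qk_rope, n_kv, n_head}      K^R broadcast to every head,
//            then permuted to {n_embd_head_qk_rope, n_head, n_kv}
// k_states : {n_embd_head_qk_nope + n_embd_head_qk_rope, n_head, n_kv}
// q_states : permuted to {.., n_tokens, n_head}; k_states to {.., n_kv, n_head}
// kq       : {n_kv, n_tokens, n_head}                 per-head attention scores
// v_states : permuted to {n_kv, n_embd_head_v, n_head}, made contiguous
// kqv      : {n_embd_head_v, n_tokens, n_head} -> ggml_cont_2d ->
//            {n_embd_head_v*n_head, n_tokens}, then projected through wo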
@@ -12004,6 +12055,24 @@ struct llama_context * llama_new_context_with_model(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }

+        {
+            size_t memory_size_kr = 0;
+            size_t memory_size_kv = 0;
+
+            for (auto & kr : ctx->kv_self.kr_l) {
+                memory_size_kr += ggml_nbytes(kr);
+            }
+
+            for (auto & kv : ctx->kv_self.kv_l) {
+                memory_size_kv += ggml_nbytes(kv);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+                    (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+                    ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+                    ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
+        }
+
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs
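For reference, a hypothetical line this would print for a 60-layer DeepSeek-V2-style model at kv_size 4096 with the F32 defaults (1 MiB K^R + 8 MiB c^KV per layer, per the earlier sketch):

    llama_new_context_with_model: KV self size =  540.00 MiB, K^R (f16):   60.00 MiB, c^KV (f16):  480.00 MiB

Note that both type names passed to the format string are ggml_type_name(type_k), i.e. the regular K cache type, while the MLA tensors are allocated with type_kr/type_kv (F32 by default), so the printed "(f16)" labels would not reflect the actual latent cache types.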
