Skip to content

Commit 925af99

Browse files
committed
Simplified is_mla branch in llm_build_deepseek2()
1 parent a5df71e commit 925af99

File tree

1 file changed

+9 -9 lines changed

1 file changed

+9 -9 lines changed

src/llama-model.cpp

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10147,27 +10147,27 @@ struct llm_build_deepseek2 : public llm_graph_context {
1014710147
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
1014810148
cb(q_nope, "q_nope_perm", il);
1014910149

10150+
// {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
1015010151
ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
1015110152
cb(q_nope_absorbed, "q_nope_absorbed", il);
1015210153

10153-
// {n_embd_head_qk_rope, n_tokens, n_head}
10154-
q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
10155-
cb(q_pe, "q_pe_perm", il);
10154+
// {kv_lora_rank, n_head, n_tokens}
10155+
q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
10156+
cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
1015610157

10158+
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
1015710159
// note: rope must go first for in-place context shifting in build_rope_shift()
1015810160
ggml_tensor * q_states = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
1015910161
cb(q_states, "q_states", il);
1016010162

10161-
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
10162-
q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3);
10163-
cb(q_states, "q_states_perm", il);
10164-
10165-
k_pe = ggml_reshape_2d(ctx0, k_pe, n_embd_head_qk_rope, n_tokens);
10166-
cb(k_pe, "k_pe_reshape", il);
10163+
kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
10164+
cb(kv_cmpr, "kv_cmpr_reshape", il);
1016710165

10166+
// {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
1016810167
ggml_tensor * k_states = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
1016910168
cb(k_states, "k_states", il);
1017010169

10170+
// {kv_lora_rank, 1, n_tokens}
1017110171
ggml_tensor * v_states = kv_cmpr;
1017210172
cb(v_states, "v_states", il);
1017310173

0 commit comments

Comments
 (0)