Commit 8c02442

Merged using squash to remove all noise commit messages
1 parent e391d3e commit 8c02442

13 files changed: +408 −114 lines

convert_hf_to_gguf.py

Lines changed: 33 additions & 2 deletions
@@ -4335,6 +4335,10 @@ def set_vocab(self):
         self._set_vocab_gpt2()
 
     def set_gguf_parameters(self):
+
+        # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group)
+        self.hparams["num_key_value_heads"] = 1
+
         super().set_gguf_parameters()
         hparams = self.hparams
 
@@ -4343,8 +4347,13 @@ def set_gguf_parameters(self):
         if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
             self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
         self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
-        self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
-        self.gguf_writer.add_value_length(hparams["v_head_dim"])
+
+        # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
+        self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length(hparams["kv_lora_rank"])
+        self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])
+
         self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
         self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
@@ -4413,6 +4422,28 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         else:
             return []
 
+        # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed
+        if name.endswith("kv_b_proj.weight"):
+            name_kb = name.replace("kv_b_proj", "k_b_proj")
+            name_vb = name.replace("kv_b_proj", "v_b_proj")
+
+            n_head_kv = self.hparams["num_key_value_heads"]
+            v_head_dim = self.hparams["v_head_dim"]
+            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
+
+            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
+
+            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
+            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
+            k_b = k_b.transpose(1, 2)
+            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
+            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])
+
+            return [
+                (self.map_tensor_name(name_kb), k_b),
+                (self.map_tensor_name(name_vb), v_b)
+            ]
+
         return [(self.map_tensor_name(name), data_torch)]
 
     def prepare_tensors(self):
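To see what the kv_b_proj split above produces, here is a minimal standalone sketch of the same view/split/transpose, assuming DeepSeek-V2-Lite-like dimensions (the head count and sizes below are illustrative assumptions, not values taken from this commit):

# Standalone sketch of the kv_b_proj -> k_b_proj / v_b_proj split (illustrative dims).
import torch

n_head           = 16    # heads that share the single compressed KV group
qk_nope_head_dim = 128   # non-RoPE part of each key head
v_head_dim       = 128   # value head size
kv_lora_rank     = 512   # width of the compressed latent

# HF stores kv_b_proj as [n_head * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
kv_b_proj = torch.randn(n_head * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv_b = kv_b_proj.view(n_head, v_head_dim + qk_nope_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)

# k_b is transposed so the "absorbed" matmul can run directly against the latent KV cache
k_b = k_b.transpose(1, 2).reshape(n_head * kv_lora_rank, qk_nope_head_dim)
v_b = v_b.reshape(n_head * v_head_dim, kv_lora_rank)

print(k_b.shape)  # torch.Size([8192, 128])
print(v_b.shape)  # torch.Size([2048, 512])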

gguf-py/gguf/constants.py

Lines changed: 8 additions & 0 deletions
@@ -138,6 +138,8 @@ class Attention:
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW = "{arch}.attention.sliding_window"
         SCALE = "{arch}.attention.scale"
+        KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
+        VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
 
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -377,6 +379,8 @@ class MODEL_TENSOR(IntEnum):
     ATTN_Q_B = auto()
     ATTN_KV_A_MQA = auto()
     ATTN_KV_B = auto()
+    ATTN_K_B = auto()
+    ATTN_V_B = auto()
     ATTN_Q_A_NORM = auto()
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM = auto()
@@ -581,6 +585,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
     MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
     MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
+    MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
     MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
     MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -1451,6 +1457,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_Q_B,
         MODEL_TENSOR.ATTN_KV_A_MQA,
         MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
         MODEL_TENSOR.ATTN_Q_A_NORM,
         MODEL_TENSOR.ATTN_KV_A_NORM,
         MODEL_TENSOR.ATTN_OUT,

gguf-py/gguf/gguf_writer.py

Lines changed: 6 additions & 0 deletions
@@ -689,6 +689,12 @@ def add_key_length(self, length: int) -> None:
     def add_value_length(self, length: int) -> None:
         self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length)
 
+    def add_key_length_mla(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length)
+
+    def add_value_length_mla(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
+
     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
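For a concrete picture of what these writer methods receive, the following sketch mirrors the set_gguf_parameters() changes above with assumed DeepSeek-V2-Lite-style hyperparameters (kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128, v_head_dim=128 are assumptions for illustration, not values from this commit):

# Hedged sketch: the values the converter would hand to the new writer methods.
hparams = {
    "kv_lora_rank": 512,
    "qk_rope_head_dim": 64,
    "qk_nope_head_dim": 128,
    "v_head_dim": 128,
}

# MQA view: one latent "head" of 512 + 64 = 576 for K and 512 for V (what the KV cache stores)
key_length   = hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]          # 576
value_length = hparams["kv_lora_rank"]                                        # 512

# MLA view: per-head sizes after decompression, 128 + 64 = 192 for K and 128 for V
key_length_mla   = hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]  # 192
value_length_mla = hparams["v_head_dim"]                                      # 128

print(key_length, value_length, key_length_mla, value_length_mla)  # 576 512 192 128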

gguf-py/gguf/tensor_mapping.py

Lines changed: 8 additions & 0 deletions
@@ -656,6 +656,14 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
         ),
 
+        MODEL_TENSOR.ATTN_K_B: (
+            "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_V_B: (
+            "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+        ),
+
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
         ),
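As a small illustration (plain string formatting only, not the gguf-py TensorNameMap API), the two new mappings mean the tensors emitted by modify_tensors() end up with GGUF names that match the blk.%d.attn_k_b / blk.%d.attn_v_b entries added on the C++ side:

# Illustrative only; the real lookup goes through gguf-py's tensor name map.
new_entries = {
    "model.layers.{bid}.self_attn.k_b_proj": "blk.{bid}.attn_k_b",
    "model.layers.{bid}.self_attn.v_b_proj": "blk.{bid}.attn_v_b",
}
for hf_name, gguf_name in new_entries.items():
    print(f"{hf_name.format(bid=0)}.weight -> {gguf_name.format(bid=0)}.weight")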

src/llama-arch.cpp

Lines changed: 6 additions & 17 deletions
@@ -135,6 +135,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
     { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
+    { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
+    { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -1030,6 +1032,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
         { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
         { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
+        { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
+        { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
         { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
         { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
         { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1471,23 +1475,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
-    {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},

src/llama-arch.h

Lines changed: 4 additions & 0 deletions
@@ -139,6 +139,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_KEY_LENGTH_MLA,
+    LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -299,6 +301,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_B,
     LLM_TENSOR_ATTN_KV_A_MQA,
     LLM_TENSOR_ATTN_KV_B,
+    LLM_TENSOR_ATTN_K_B,
+    LLM_TENSOR_ATTN_V_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,

src/llama-context.cpp

Lines changed: 7 additions & 3 deletions
@@ -10,6 +10,7 @@
 #include <cstring>
 #include <stdexcept>
 #include <cinttypes>
+#include <math.h>
 
 //
 // llama_context
@@ -473,7 +474,6 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
     const auto & yarn_ext_factor = cparams.yarn_ext_factor;
-    const auto & yarn_attn_factor = cparams.yarn_attn_factor;
     const auto & yarn_beta_fast = cparams.yarn_beta_fast;
     const auto & yarn_beta_slow = cparams.yarn_beta_slow;
 
@@ -482,6 +482,10 @@ ggml_tensor * llama_context::build_rope_shift(
     const auto & n_rot = hparams.n_rot;
     const auto & rope_type = hparams.rope_type;
 
+    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
+    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+
     ggml_tensor * tmp;
 
     if (ggml_is_quantized(cur->type)) {
@@ -500,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(
 
         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
 
         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
     }
 
     return tmp;
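As a quick sanity check on the DEEPSEEK2-specific scaling above, this small sketch evaluates 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) for an assumed freq_scale of 1/40 (i.e. a YaRN rope scaling factor of 40, a typical DeepSeek-V2 setting; treat the value as an assumption, not something read from this commit):

# Hedged sketch of the yarn_attn_factor_scaled expression for the DEEPSEEK2 arch.
# Any other arch keeps cparams.yarn_attn_factor unchanged.
import math

freq_scale = 1.0 / 40.0
yarn_attn_factor_scaled = 1.0 / (1.0 + 0.1 * math.log(1.0 / freq_scale))
print(f"{yarn_attn_factor_scaled:.4f}")  # ~0.7305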
