
Optimized DeepSeek V2/V3 implementation (MLA + flash attention) #12227


Closed · wants to merge 5 commits
23 changes: 23 additions & 0 deletions convert_hf_to_gguf.py
@@ -4141,6 +4141,29 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
            else:
                return []

        # deepseek2-mla: split kv_b_proj into k_b_proj and (transposed) v_b_proj
        if name.endswith("kv_b_proj.weight"):
            name_kb = name.replace("kv_b_proj", "k_b_proj")
            name_vb = name.replace("kv_b_proj", "v_b_proj")

            n_head_kv = self.hparams["num_key_value_heads"]
            v_head_dim = self.hparams["v_head_dim"]
            qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

            assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

            kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
            k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
            k_b = k_b.transpose(1, 2)
            k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
            v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])

            return [
                (self.map_tensor_name(name), data_torch),
                (self.map_tensor_name(name_kb), k_b),
                (self.map_tensor_name(name_vb), v_b)
            ]

        return [(self.map_tensor_name(name), data_torch)]

    def prepare_tensors(self):
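For reference, a minimal standalone sketch of the split performed above. The hyperparameter values are illustrative assumptions (roughly DeepSeek-V2-Lite-sized) rather than values read from any config, and the weight is random:

import torch

n_head_kv        = 16    # num_key_value_heads
qk_nope_head_dim = 128
v_head_dim       = 128
kv_lora_rank     = 512   # width of the compressed KV latent (input of kv_b_proj)

# kv_b_proj.weight holds one (k_nope | v) block per head, projected out of the latent
kv_b_proj = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv_b = kv_b_proj.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)

# k_b is stored transposed so it can later be applied directly to the cached latent
k_b = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

print(k_b.shape)  # torch.Size([8192, 128]) -> attn_k_b
print(v_b.shape)  # torch.Size([2048, 512]) -> attn_v_b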
18 changes: 18 additions & 0 deletions ggml/src/ggml-cuda/fattn.cu
@@ -153,6 +153,8 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)

FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

FATTN_VEC_F16_CASE(576, GGML_TYPE_F16, GGML_TYPE_F16) // deepseek2-mla: for large 576 embedding to work
#else
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

@@ -161,6 +163,8 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

FATTN_VEC_F16_CASE(576, GGML_TYPE_F16, GGML_TYPE_F16) // deepseek2-mla: for large 576 embedding to work
#endif // GGML_CUDA_FA_ALL_QUANTS

on_no_fattn_vec_case(Q->ne[0]);
@@ -228,6 +232,8 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)

FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

FATTN_VEC_F32_CASE(576, GGML_TYPE_F16, GGML_TYPE_F16) // deepseek2-mla: for large 576 embedding to work
#else
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

@@ -236,6 +242,8 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

FATTN_VEC_F32_CASE(576, GGML_TYPE_F16, GGML_TYPE_F16) // deepseek2-mla: for large 576 embedding to work
#endif // GGML_CUDA_FA_ALL_QUANTS

on_no_fattn_vec_case(Q->ne[0]);
@@ -253,6 +261,16 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

    // deepseek2-mla: special case to get the large 576 embedding to work
    if (Q->ne[0] == 576) {
        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }

    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
        if (fp16_mma_available(cc)) {
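The 576 head size comes from the MLA layout rather than from any existing per-head dimension: with the compressed KV cache, queries and keys are compared in the latent space of width kv_lora_rank plus the decoupled RoPE part, while values are the latent itself, so only the vector kernels (extended with the new cases above) can serve it. A small sketch of that arithmetic; the 512/64 widths are the usual DeepSeek V2/V3 values, assumed here rather than taken from this diff:

kv_lora_rank        = 512  # compressed KV latent width (hparams.n_lora_kv)
n_embd_head_qk_rope = 64   # decoupled RoPE part of each Q/K head (hparams.n_rot)

head_size_qk = kv_lora_rank + n_embd_head_qk_rope  # 576 -> the Q->ne[0] checked in the dispatch above
head_size_v  = kv_lora_rank                        # 512 -> what the kernel uses for V

assert (head_size_qk, head_size_v) == (576, 512)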
6 changes: 6 additions & 0 deletions gguf-py/gguf/constants.py
@@ -356,6 +356,8 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
ATTN_K_B = auto() # deepseek2-mla: split into attn_k_b
ATTN_V_B = auto() # deepseek2-mla: and (transposed) attn_v_b
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -543,6 +545,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", # deepseek2-mla: split into attn_k_b
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", # deepseek2-mla: and (transposed) attn_v_b
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -1333,6 +1337,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B, # deepseek2-mla: split into attn_k_b
MODEL_TENSOR.ATTN_V_B, # deepseek2-mla: and (transposed) attn_v_b
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
8 changes: 8 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -586,6 +586,14 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_K_B: (
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2-mla: split into attn_k_b
),

MODEL_TENSOR.ATTN_V_B: (
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2-mla: and (transposed) attn_v_b
),

MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
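As a quick illustration of the name plumbing added in constants.py and tensor_mapping.py above, the converter's new HF-style names end up as the following GGUF tensor names. This is plain string formatting for illustration only, not the actual gguf-py TensorNameMap API (which also handles the ".weight"/".bias" suffixes):

GGUF_NAMES = {
    "k_b_proj": "blk.{bid}.attn_k_b",
    "v_b_proj": "blk.{bid}.attn_v_b",
}

def gguf_base_name(hf_name: str) -> str:
    # "model.layers.3.self_attn.k_b_proj" -> "blk.3.attn_k_b"
    parts = hf_name.split(".")
    return GGUF_NAMES[parts[-1]].format(bid=parts[2])

print(gguf_base_name("model.layers.3.self_attn.k_b_proj"))  # blk.3.attn_k_b
print(gguf_base_name("model.layers.3.self_attn.v_b_proj"))  # blk.3.attn_v_b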
21 changes: 4 additions & 17 deletions src/llama-arch.cpp
@@ -999,6 +999,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, // deepseek2-mla: split into attn_k_b
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, // deepseek2-mla: and (transposed) attn_v_b
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
@@ -1333,23 +1335,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // deepseek2-mla: split into attn_k_b
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // deepseek2-mla: and (transposed) attn_v_b
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2 changes: 2 additions & 0 deletions src/llama-arch.h
@@ -277,6 +277,8 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_K_B, // deepseek2-mla: split into attn_k_b
LLM_TENSOR_ATTN_V_B, // deepseek2-mla: and (transposed) attn_v_b
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
18 changes: 16 additions & 2 deletions src/llama-kv-cache.cpp
@@ -91,8 +91,22 @@ bool llama_kv_cache_init(
return false;
}

        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
        ggml_tensor * k;
        ggml_tensor * v;
        if (model.arch == LLM_ARCH_DEEPSEEK2) {
            const uint32_t n_embd_head_qk_rope = hparams.n_rot;
            const uint32_t kv_lora_rank = hparams.n_lora_kv;
            k = ggml_new_tensor_1d(ctx, type_k, (kv_lora_rank+n_embd_head_qk_rope)*kv_size);
            if (cparams.flash_attn) {
                v = ggml_new_tensor_1d(ctx, type_v, 0); // FA reuses k in place of v
            } else {
                v = ggml_new_tensor_1d(ctx, type_v, kv_lora_rank*kv_size); // transposed for non-FA
            }
        } else {
            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
        }

        ggml_format_name(k, "cache_k_l%d", i);
        ggml_format_name(v, "cache_v_l%d", i);
        cache.k_l.push_back(k);
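The practical effect of this layout is a much smaller KV cache: per token and per layer, K stores one compressed latent plus the shared RoPE key instead of full per-head keys, and V either stores the latent transposed (non-FA path) or nothing at all (the FA path reads V back out of the K cache). A rough back-of-the-envelope comparison, assuming an F16 cache and illustrative DeepSeek-V3-sized dimensions rather than values read from this diff:

n_layer             = 61
kv_lora_rank        = 512
n_embd_head_qk_rope = 64
bytes_per_elem      = 2        # F16 cache

# MLA cache per token: K holds (latent + rope); V is empty with flash attention
mla_fa    = n_layer * (kv_lora_rank + n_embd_head_qk_rope) * bytes_per_elem
mla_no_fa = mla_fa + n_layer * kv_lora_rank * bytes_per_elem

# naive MHA-style cache per token for comparison: n_head * (K head + V head) per layer
n_head, qk_nope, qk_rope, v_head = 128, 128, 64, 128
naive = n_layer * n_head * ((qk_nope + qk_rope) + v_head) * bytes_per_elem

print(mla_fa // 1024, mla_no_fa // 1024, naive // 1024)  # ~68, ~129, ~4880 KiB per token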
8 changes: 5 additions & 3 deletions src/llama-model.cpp
@@ -2895,9 +2895,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
}

layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope }, 0);
layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank }, 0); // deepseek2-mla: split into attn_k_b
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v }, 0); // deepseek2-mla: and (transposed) attn_v_b
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v, n_embd }, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

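The new wk_b / wv_b tensors exist so that attention can run directly against the compressed KV cache ("absorbed" MLA): the no-RoPE part of each query head is first mapped into the latent space with attn_k_b, the attention scores and the weighted sum are taken over the cached latents, and attn_v_b maps the per-head latent output back to the value head. A single-head, single-token sketch with random weights; the shapes are illustrative and do not reflect the exact ggml memory layout:

import torch

kv_lora_rank, qk_nope, qk_rope, v_head, n_kv = 512, 128, 64, 128, 10

q_nope = torch.randn(qk_nope)               # no-RoPE part of the query head
q_rope = torch.randn(qk_rope)               # RoPE part of the query head
kv_lat = torch.randn(n_kv, kv_lora_rank)    # cached compressed latents
k_rope = torch.randn(n_kv, qk_rope)         # cached shared RoPE keys

wk_b = torch.randn(qk_nope, kv_lora_rank)   # per-head slice of attn_k_b
wv_b = torch.randn(kv_lora_rank, v_head)    # per-head slice of attn_v_b

q_lat  = q_nope @ wk_b                                  # absorb k_b into the query
scores = kv_lat @ q_lat + k_rope @ q_rope               # logits against the latent cache
probs  = torch.softmax(scores / (qk_nope + qk_rope) ** 0.5, dim=-1)
o_lat  = probs @ kv_lat                                 # weighted latent, shape (kv_lora_rank,)
out    = o_lat @ wv_b                                   # back to the value head, shape (v_head,)
print(out.shape)                                        # torch.Size([128])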
2 changes: 2 additions & 0 deletions src/llama-model.h
@@ -161,6 +161,8 @@ struct llama_layer {
struct ggml_tensor * wq_b = nullptr;
struct ggml_tensor * wkv_a_mqa = nullptr;
struct ggml_tensor * wkv_b = nullptr;
struct ggml_tensor * wk_b = nullptr; // deepseek2-mla: split into attn_k_b
struct ggml_tensor * wv_b = nullptr; // deepseek2-mla: and (transposed) attn_v_b
struct ggml_tensor * wq_cross = nullptr;
struct ggml_tensor * wk_cross = nullptr;
struct ggml_tensor * wv_cross = nullptr;