Skip to content

Commit 69699be

Browse files
CUDA: fix q_nope_absorbed prec for DS 2 Lite f16 (#13137)
1 parent 85f36e5 commit 69699be

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

ggml/include/ggml.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -393,8 +393,8 @@ extern "C" {
393393

394394
// precision
395395
enum ggml_prec {
396-
GGML_PREC_DEFAULT,
397-
GGML_PREC_F32,
396+
GGML_PREC_DEFAULT = 0, // stored as ggml_tensor.op_params, 0 by default
397+
GGML_PREC_F32 = 10,
398398
};
399399

400400
// model file types

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1935,8 +1935,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
19351935
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
19361936
} else if (!split && use_mul_mat_vec_q) {
19371937
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
1938-
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
1939-
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
1938+
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
1939+
dst->op_params[0] == GGML_PREC_DEFAULT && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
19401940
// general KQ + KQV multi-batch without FlashAttention
19411941
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
19421942
} else if (use_mul_mat_vec) {

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10149,6 +10149,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
1014910149

1015010150
// {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
1015110151
ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
10152+
ggml_mul_mat_set_prec(q_nope_absorbed, GGML_PREC_F32);
1015210153
cb(q_nope_absorbed, "q_nope_absorbed", il);
1015310154

1015410155
// {kv_lora_rank, n_head, n_tokens}

0 commit comments

Comments (0)