
CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID #13014


Merged: 3 commits merged into ggml-org:master on Apr 22, 2025

Conversation

@JohannesGaessler (Collaborator) commented Apr 18, 2025

This PR makes the following changes:

  • Extend MMVQ to handle noncontiguous and batched inputs. This enables doing the KQ calculation in a single kernel launch when using quantized K and no FlashAttention.
  • Add FP32 support to MMV. There is basically no performance change for CUDA; hipBLAS seems to have been very slow for batch size 1, so on my RX 6800 this provides a speedup of almost 5x for FP32 models.
  • Extend MMV and MMVQ with support for batch size 1 MUL_MAT_ID. This reduces the number of kernel launches needed by a factor equal to the number of used src0 matrices, and it eliminates the need for a call to cudaStreamSynchronize prior to launching the kernels (a minimal sketch of the device-side expert lookup follows this list).
  • Fix the condition for MMV_MAX_ROWS: it was supposed to be <= but I accidentally wrote <. In my testing this makes a small but measurable difference for Deepseek v2 Lite. Thank you to @jukofyork for pointing this out to me.
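To make the batch size 1 MUL_MAT_ID point concrete, here is a minimal, hypothetical CUDA sketch (not the PR's actual kernel, which is templated over quantization types and handles strides and row counts more carefully). The kernel name, argument layout, and launch configuration are assumptions for illustration only: the selected expert indices are read from ids on the device, so a single launch covers all selected experts and no cudaStreamSynchronize is needed to copy the ids to the host first.

```cuda
#include <cstdint>
#include <cuda_runtime.h>

// Sketch: batch-size-1 MUL_MAT_ID as a mat-vec product per selected expert.
// One block computes one output row; the expert matrix is chosen on the
// device by reading ids[], so no host-side synchronization is required.
__global__ void mul_mat_id_vec_f32(
        const float   * src0,  // [n_expert][nrows][ncols] expert weight matrices
        const float   * src1,  // [ncols] input activations for the single token
        const int32_t * ids,   // [n_expert_used] experts selected for this token
        float         * dst,   // [n_expert_used][nrows] outputs
        const int ncols, const int nrows) {
    const int row    = blockIdx.x; // output row within one expert's matrix
    const int slot   = blockIdx.y; // which of the selected experts
    const int expert = ids[slot];  // device-side lookup, no cudaStreamSynchronize

    const float * mat = src0 + ((int64_t) expert*nrows + row)*ncols;

    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += mat[col]*src1[col];
    }

    // Block-wide tree reduction in shared memory; launch with 256 threads per block.
    __shared__ float buf[256];
    buf[threadIdx.x] = sum;
    __syncthreads();
    for (int offset = blockDim.x/2; offset > 0; offset /= 2) {
        if (threadIdx.x < offset) {
            buf[threadIdx.x] += buf[threadIdx.x + offset];
        }
        __syncthreads();
    }
    if (threadIdx.x == 0) {
        dst[(int64_t) slot*nrows + row] = buf[0];
    }
}

// Example launch: one call covers all selected experts for the token,
// instead of one launch per used src0 matrix.
// mul_mat_id_vec_f32<<<dim3(nrows, n_expert_used), 256, 0, stream>>>(src0_d, src1_d, ids_d, dst_d, ncols, nrows);
```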
Performance changes

| GPU | Model | Cache type K | Test | t/s master | t/s PR | Speedup |
|---|---|---|---|---|---|---|
| RX 6800 | deepseek 2 16B Q4_0 | F16 | tg128 | 27.26 | 52.74 | 1.93 |
| RX 6800 | llama 8B Q4_0 | F16 | tg128 | 62.02 | 60.90 | 0.98 |
| RX 6800 | llama 8B Q4_0 | q8_0 | tg4096 | 38.98 | 49.65 | 1.27 |
| RX 6800 | llama 1B all F32 | F16 | tg128 | 16.08 | 76.08 | 4.73 |
| P40 | deepseek 2 16B Q4_0 | F16 | tg128 | 47.78 | 62.87 | 1.32 |
| P40 | llama 8B Q4_0 | F16 | tg128 | 51.02 | 49.91 | 0.98 |
| P40 | llama 8B Q4_0 | q8_0 | tg4096 | 31.96 | 31.71 | 0.99 |
| P40 | llama 1B all F32 | F16 | tg128 | 53.65 | 57.46 | 1.07 |
| 2x P40 | deepseek 2 16B F16 | F16 | tg128 | 34.96 | 37.45 | 1.07 |
| RTX 3090 | deepseek 2 16B Q4_0 | F16 | tg128 | 76.58 | 130.17 | 1.70 |
| RTX 3090 | llama 8B Q4_0 | F16 | tg128 | 140.00 | 138.90 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | q8_0 | tg4096 | 86.82 | 93.21 | 1.07 |
| RTX 3090 | llama 1B all F32 | F16 | tg128 | 151.58 | 152.11 | 1.00 |
| RTX 4090 | deepseek 2 16B Q4_0 | F16 | tg128 | 55.84 | 123.55 | 2.21 |
| RTX 4090 | llama 8B Q4_0 | F16 | tg128 | 158.67 | 157.21 | 0.99 |
| RTX 4090 | llama 8B Q4_0 | q8_0 | tg4096 | 104.08 | 123.94 | 1.19 |
| RTX 4090 | llama 1B all F32 | F16 | tg128 | 167.65 | 168.72 | 1.01 |
| 2x RTX 4090 | deepseek 2 16B F16 | F16 | tg128 | 69.80 | 87.62 | 1.26 |

Notes:

  • The performance of Deepseek v2 Lite is better on my RTX 3090 than my RTX 4090. I think there is still significant kernel launch overhead so the difference comes from my 3090 being paired with a faster CPU.
  • There is a small performance regression for non-MoE models; I think even disregarding MoE this is a worthwhile tradeoff since user code now has more flexibility regarding tensor shapes for the CUDA backend.
  • The approach I used for this PR only works for batch size 1 where the arithmetic intensity is terrible anyways because you can use each loaded weight only once. I have an idea that I think will work well for MMQ but would make it necessary to compile twice as many template specializations. For floating-point data I have so far not been able to think of a good solution that doesn't involve writing my own GEMM kernel (which would be a lot of work).

@github-actions bot added labels on Apr 18, 2025: testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), ggml (changes relating to the ggml tensor library for machine learning)
@@ -2035,97 +2056,75 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;

if (ne12 == 1) {
@JohannesGaessler (Collaborator, Author) commented:
I removed the ne12 == 1 code path because that case can now be handled above. The seemingly large diff shown on GitHub is due to me changing the indentation of the else branch.

@slaren (Member) commented Apr 18, 2025

I am testing this, but with deepseek2 the ROPE is being run on the CPU, which destroys performance. If I change the supports_op to force it to run on CUDA, the performance is much better. Also, the exception that disables CUDA graphs with MUL_MAT_ID could be removed now, which would improve performance further. With both of these fixes the performance that I get on my 3090 Ti is similar to what you report here, otherwise it is much lower.

@JohannesGaessler (Collaborator, Author) commented:

I converted my GGUF file from safetensors after the recent Deepseek changes. @jukofyork, does this explain the discrepancy?

@slaren (Member) commented Apr 18, 2025

I tried with a freshly converted model from https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct, but the ROPE is still being run on the CPU, so that doesn't seem to be the problem.

I added a ggml_cont to the graph so that ROPE can still be run on CUDA:

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | tg128 | 94.89 ± 1.95 |
| deepseek2 16B Q8_0 | 15.55 GiB | 15.71 B | CUDA | 99 | tg128 | 85.38 ± 0.36 |

If additionally I enable CUDA graphs (not sure if this is completely correct):

| model | size | params | backend | ngl | test | t/s |
|---|---|---|---|---|---|---|
| deepseek2 16B Q4_0 | 8.29 GiB | 15.71 B | CUDA | 99 | tg128 | 182.85 ± 0.94 |
| deepseek2 16B Q8_0 | 15.55 GiB | 15.71 B | CUDA | 99 | tg128 | 151.66 ± 0.80 |

```diff
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index fdae68a41..17db36427 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2488,12 +2488,14 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }

+#if 0
         if (node->op == GGML_OP_MUL_MAT_ID) {
             use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
 #endif
         }
+#endif

         if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
             // disable CUDA graphs for batch size > 1 for now.
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 6b7bfecf3..728d109ed 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -10124,6 +10124,8 @@ struct llm_build_deepseek2 : public llm_graph_context {
                         ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
                         ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
+                k_pe = ggml_cont(ctx0, k_pe);
+                q_pe = ggml_cont(ctx0, q_pe);

                 q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
```

@JohannesGaessler (Collaborator, Author) commented:

Oops, I was only looking at the runtimes of the CUDA kernels and forgot to check whether some ggml ops still run on the CPU 😅. Non-contiguous inputs for RoPE are already implemented, but the check in ggml_backend_cuda_device_supports_op was wrong. CUDA graphs can be enabled for MUL_MAT_ID when the batch size is 1; for other batch sizes a workaround is still needed (though I think CUDA graphs are currently disabled for batch sizes > 1 anyways). With those two fixes the performance on my systems is now:

| GPU | Model | Test | t/s master | t/s PR old | t/s PR new | Speedup vs. master | Speedup vs. PR old |
|---|---|---|---|---|---|---|---|
| RX 6800 | deepseek 2 16B Q4_0 | pp512 | 565.29 | - | 637.31 | 1.13 | - |
| RX 6800 | deepseek 2 16B Q4_0 | tg128 | 27.26 | 52.74 | 71.95 | 2.64 | 1.36 |
| P40 | deepseek 2 16B Q4_0 | pp512 | 727.03 | - | 841.71 | 1.16 | - |
| P40 | deepseek 2 16B Q4_0 | tg128 | 47.78 | 62.87 | 71.13 | 1.49 | 1.13 |
| RTX 3090 | deepseek 2 16B Q4_0 | pp512 | 1885.98 | - | 2344.70 | 1.24 | - |
| RTX 3090 | deepseek 2 16B Q4_0 | tg128 | 76.58 | 130.17 | 184.17 | 2.40 | 1.41 |
| RTX 4090 | deepseek 2 16B Q4_0 | pp512 | 2882.90 | - | 3570.49 | 1.24 | - |
| RTX 4090 | deepseek 2 16B Q4_0 | tg128 | 55.84 | 123.55 | 223.27 | 4.00 | 1.81 |
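
As a hedged illustration of the CUDA graphs change described above (hypothetical condition, not the exact diff from this PR): instead of unconditionally disabling graph capture for MUL_MAT_ID, the exception only needs to apply when the batch size is larger than 1, since the batch-size-1 path no longer requires host-side synchronization. The tensor dimension used to detect the batch size here is an assumption.

```cpp
// Hypothetical sketch: keep CUDA graph capture enabled for batch-size-1
// MUL_MAT_ID and only fall back to disabling it for larger batches.
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
    use_cuda_graph = false; // batched MUL_MAT_ID still needs the old workaround
#ifndef NDEBUG
    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batched MUL_MAT_ID\n", __func__);
#endif
}
```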

@slaren (Member) commented Apr 19, 2025

Comparison with master on Windows (WSL, RTX 3090 Ti):

| Model | Test | t/s master | t/s cuda-moe-mmv-2 | Speedup |
|---|---|---|---|---|
| deepseek2 16B Q4_0 | tg128 | 29.13 | 190.94 | 6.56 |
| llama 8x7B Q3_K_S | tg128 | 45.88 | 70.85 | 1.54 |

@slaren (Member) left a comment:
I also see a similar speedup with K quantization.

```cpp
GGML_ASSERT(dst->type == GGML_TYPE_F32);

void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
```
@slaren (Member) commented:
I noticed that the code assumes that ids is always contiguous, but this is only true with bs = 1, so that's something to keep in mind when extending this to bs > 1. More generally, there is no guarantee that a tensor will be contiguous, so it would be good to add a check as well in case ids is no longer contiguous in the future (see the sketch below).
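
A minimal sketch of the kind of guard being suggested here (not part of the PR), reusing the existing GGML_ASSERT and ggml_is_contiguous helpers:

```cpp
// Hypothetical defensive check, not in the PR: fail loudly if ids is ever
// passed in non-contiguous instead of silently reading it with wrong strides.
GGML_ASSERT(!ids || ggml_is_contiguous(ids));
```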

@JohannesGaessler (Collaborator, Author) commented:
For bs > 1 I think the code needs to be written in a fundamentally different way; I agree that the code only works correctly for batch size 1 but I don't intend for it to be applied to any other batch sizes in the first place. I think for batch sizes > 1 the way to go is to first determine which rows should be fetched for which matrix and to then extend the MMQ code to allow for batched matrix multiplications with variable numbers of non-contiguous tokens. I think cuBLAS unfortunately only supports batched matrix multiplications where all matrices have the same shape.
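
A hedged host-side sketch of the first step described above (hypothetical helper, not code from this PR): bucket the token indices by expert so that a later grouped matrix multiplication can process each expert's tokens as one variable-length list of generally non-contiguous rows.

```cpp
#include <cstdint>
#include <vector>

// Group token indices by the expert they were routed to.
std::vector<std::vector<int32_t>> bucket_rows_by_expert(
        const int32_t * ids,   // [n_tokens][n_expert_used] selected expert ids
        const int64_t n_tokens,
        const int64_t n_expert_used,
        const int64_t n_expert) {
    std::vector<std::vector<int32_t>> rows(n_expert);
    for (int64_t t = 0; t < n_tokens; ++t) {
        for (int64_t e = 0; e < n_expert_used; ++e) {
            const int32_t expert = ids[t*n_expert_used + e];
            rows[expert].push_back((int32_t) t); // token t contributes one row to this expert
        }
    }
    return rows; // rows[i]: tokens routed to expert i, variable length per expert
}
```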

@slaren (Member) commented Apr 21, 2025
Yes, I think this requires a fundamentally different algorithm than a batched GEMM.

One option would be to use CUTLASS, since it already has a good grouped GEMM implementation. When I looked into it, it only supported Hopper GPUs, so I gave up on it, but that may have changed now.

It seems that cuBLAS now has a grouped GEMM that could be used to implement this: https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmgroupedbatched (blog). However, without quantization support I am not sure how useful this would be.

@JohannesGaessler JohannesGaessler merged commit 658987c into ggml-org:master Apr 22, 2025
47 of 48 checks passed
  • BradHutchings added a commit to BradHutchings/llama-server-one that referenced this pull request on Apr 22, 2025: "CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID (ggml-org#13014)"
  • Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request on Apr 24, 2025
  • pockers21 pushed a commit to pockers21/llama.cpp that referenced this pull request on Apr 28, 2025: "CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID" / "fix logic for RoPE support, CUDA graphs"