CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID #13014
Conversation
@@ -2035,97 +2056,75 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
    dst_row.nb[2] = nb1;
    dst_row.nb[3] = nb1;

    if (ne12 == 1) {
I removed the `ne12 == 1` code path because that case can now be handled above. The seemingly large diff shown on GitHub is due to me changing the indentation of the else branch.
I am testing this, but with deepseek2 the RoPE is being run on the CPU, which destroys performance. If I change the …
I converted my gguf file from safetensors after the recent deepseek changes. @jukofyork, does this explain the discrepancy?
I tried with a freshly converted model from https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct, but the RoPE is still being run on the CPU, so that doesn't seem to be the problem. I added a `ggml_cont` to the RoPE inputs (see the diff below).
If additionally I enable CUDA graphs (not sure if this is completely correct):
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index fdae68a41..17db36427 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2488,12 +2488,14 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
     }
+#if 0
     if (node->op == GGML_OP_MUL_MAT_ID) {
         use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
 #endif
     }
+#endif
     if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
         // disable CUDA graphs for batch size > 1 for now.
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 6b7bfecf3..728d109ed 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -10124,6 +10124,8 @@ struct llm_build_deepseek2 : public llm_graph_context {
                 ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
                 ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
         cb(k_pe, "k_pe", il);
+        k_pe = ggml_cont(ctx0, k_pe);
+        q_pe = ggml_cont(ctx0, q_pe);
         q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
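As a side note, here is a minimal sketch (illustrative only, not part of the PR or the diff above) of what the `ggml_cont` workaround does: it materializes a contiguous copy of a strided view, so the consuming op no longer sees non-contiguous data, at the cost of an extra copy per graph evaluation.

```cpp
#include "ggml.h"

// Illustrative sketch, not part of the PR: ggml_cont adds a node that makes a
// contiguous copy of a (possibly strided) view, so a downstream op such as
// ggml_rope_ext receives contiguous data. The trade-off is an extra copy at
// every evaluation, which is why relaxing the backend check is the better fix.
static struct ggml_tensor * make_contiguous_for_rope(struct ggml_context * ctx0,
                                                     struct ggml_tensor * k_pe) {
    // k_pe stands in for the strided view created in the deepseek2 graph.
    return ggml_cont(ctx0, k_pe);
}
```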
Oops, I was only looking at the runtimes of CUDA kernels and forgot to check whether some ggml ops still run on the CPU 😅. Non-contiguous inputs for RoPE are already implemented, but the check in …
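A hedged sketch of the kind of fix implied here; the function the check actually lives in is cut off above, so the name and structure below are assumptions for illustration only.

```cpp
#include "ggml.h"

// Hedged sketch, not the actual llama.cpp change: if the CUDA RoPE kernel
// already reads its input through explicit strides, a backend capability
// check that demands a contiguous src[0] is overly strict and pushes the op
// (and with it a lot of performance) onto the CPU.
static bool cuda_supports_rope_sketch(const struct ggml_tensor * op) {
    // Overly strict version:
    //     return ggml_is_contiguous(op->src[0]);
    // Relaxed version: contiguity is not actually required by the kernel.
    (void) op;
    return true;
}
```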
Comparison with master on Windows (WSL, RTX 3090 Ti):
I also see a similar speedup with K quantization.
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
I noticed that the code assumes that `ids` is always contiguous, but this is only true with bs = 1, so that's something to keep in mind when extending this to bs > 1. Generally there is no guarantee that a tensor will be contiguous, so it would be good to add a check just in case `ids` is no longer contiguous in the future.
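A minimal sketch of the suggested safety check, assuming it would sit next to the existing asserts in `ggml_cuda_mul_mat_vec` quoted above:

```cpp
// Sketch of the suggested defensive check: make the contiguity assumption on
// ids explicit, so a future non-contiguous ids tensor trips an assert instead
// of being indexed with the wrong strides.
GGML_ASSERT(!ids || ggml_is_contiguous(ids));
```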
For bs > 1 I think the code needs to be written in a fundamentally different way; I agree that the code only works correctly for batch size 1 but I don't intend for it to be applied to any other batch sizes in the first place. I think for batch sizes > 1 the way to go is to first determine which rows should be fetched for which matrix and to then extend the MMQ code to allow for batched matrix multiplications with variable numbers of non-contiguous tokens. I think cuBLAS unfortunately only supports batched matrix multiplications where all matrices have the same shape.
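A rough sketch (names and layout assumed, not from the PR) of the first step described above: grouping row indices by expert on the host, so that a later batched or grouped multiplication can handle a variable number of rows per `src0` matrix.

```cpp
#include <cstdint>
#include <vector>

// Rough sketch, not llama.cpp code: group token indices by expert so a later
// batched matrix multiplication can process a variable number of rows per
// expert. ids is assumed to be a contiguous [n_tokens][n_used_experts] array
// of expert indices as selected by the MoE router.
static std::vector<std::vector<int32_t>> rows_per_expert(
        const int32_t * ids, int64_t n_tokens, int64_t n_used_experts, int64_t n_experts) {
    std::vector<std::vector<int32_t>> rows(n_experts);
    for (int64_t t = 0; t < n_tokens; ++t) {
        for (int64_t e = 0; e < n_used_experts; ++e) {
            const int32_t expert = ids[t*n_used_experts + e];
            rows[expert].push_back((int32_t) t); // expert matrix `expert` must process token t
        }
    }
    return rows; // rows[i] lists the src1 rows that src0 matrix i has to multiply
}
```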
Yes, I think this requires a fundamentally different algorithm than a batched GEMM.
One option would be to use CUTLASS, since it already has a good grouped GEMM implementation. When I looked into it, it only supported Hopper GPUs, so I gave up on it, but that may have changed now.
It seems that cuBLAS now has a grouped GEMM that could be used to implement this: https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemmgroupedbatched (blog). However, without quantization support I am not sure how useful this would be.
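To make the shape limitation discussed above concrete, here is a hedged sketch of the plain batched cuBLAS entry point: `cublasGemmBatchedEx` takes a single (m, n, k) and single leading dimensions for the entire batch, which is why variable per-expert row counts need padding, a grouped API, or a custom kernel. The parameter choices (all-F32, no error handling) are illustrative, not llama.cpp code.

```cpp
#include <cublas_v2.h>

// Hedged sketch, not llama.cpp code: every one of the batchCount GEMMs in a
// cublasGemmBatchedEx call shares the same m, n, k and leading dimensions,
// so all matrices in the batch must have the same shape.
static cublasStatus_t same_shape_batched_gemm(cublasHandle_t handle,
                                              int m, int n, int k,
                                              const void * const * dA, // device array of device pointers
                                              const void * const * dB,
                                              void * const * dC,
                                              int batchCount) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    return cublasGemmBatchedEx(handle,
        CUBLAS_OP_N, CUBLAS_OP_N,
        m, n, k,                // one shape for the whole batch
        &alpha,
        dA, CUDA_R_32F, m,      // lda shared by all batch entries
        dB, CUDA_R_32F, k,      // ldb shared by all batch entries
        &beta,
        dC, CUDA_R_32F, m,      // ldc shared by all batch entries
        batchCount,
        CUBLAS_COMPUTE_32F,
        CUBLAS_GEMM_DEFAULT);
}
```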
CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID (ggml-org#13014)
* CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID
* fix logic for RoPE support, CUDA graphs
This PR makes the following changes:

- Support for noncontiguous MMVQ inputs, used to do a single batched MMVQ launch for batch size 1 `MUL_MAT_ID`. This reduces the number of kernel launches needed by a factor equal to the number of used `src0` matrices. It also eliminates the need for a call to `cudaStreamSynchronize` prior to launching the kernels.
- A fix for the comparison against `MMV_MAX_ROWS`: it was supposed to be `<=` but I accidentally wrote `<`. In my testing this makes a small but measurable difference for Deepseek v2 Lite. Thank you to @jukofyork for pointing this out to me.

Performance changes
Notes: