
vulkan: Implement split_k for coopmat2 flash attention. #12627


Merged

jeffbolznv merged 1 commit into ggml-org:master on Apr 2, 2025

Conversation

jeffbolznv (Collaborator) commented:
This is stacked on #12559.

When using grouped-query attention, there is only one workgroup per KV batch, which can be very few workgroups (e.g. just 8 in some models). Enable split_k to spread the work across the GPU's SMs. This helps a lot when the KV cache is large.
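For context on what split_k does here: each split computes flash attention over its own slice of the KV cache, producing a partial (unnormalized) output along with the running row max and sum of exponentials, and a second pass merges the partials with a log-sum-exp rescale. The sketch below is a minimal host-side C++ illustration of that merge step, assuming that representation of the partials; the names (`PartialAttn`, `combine_splits`) are illustrative and this is not the actual coopmat2 shader code.

```cpp
// Illustrative sketch of the split-K reduction step (not the real shader).
// Each split i produces: O_i = sum_j exp(s_j - m_i) * v_j (unnormalized),
// m_i = max score in its slice, l_i = sum_j exp(s_j - m_i).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct PartialAttn {
    std::vector<float> O; // unnormalized partial output, length = head size D
    float m;              // row max of the scores seen by this split
    float l;              // sum of exp(score - m) over this split
};

// Combine per-split partials into the final attention output for one row.
std::vector<float> combine_splits(const std::vector<PartialAttn>& parts, size_t D) {
    // Global max across all splits, used to rescale each partial safely.
    float m_all = -INFINITY;
    for (const auto& p : parts) {
        m_all = std::max(m_all, p.m);
    }

    std::vector<float> O(D, 0.0f);
    float l_all = 0.0f;
    for (const auto& p : parts) {
        const float scale = std::exp(p.m - m_all); // rescale to the global max
        l_all += p.l * scale;
        for (size_t d = 0; d < D; ++d) {
            O[d] += p.O[d] * scale;
        }
    }
    // Final softmax normalization over the combined sum of exponentials.
    for (size_t d = 0; d < D; ++d) {
        O[d] /= l_all;
    }
    return O;
}
```

The merge is cheap relative to the per-slice attention passes, so splitting the KV range lets the slices run concurrently across SMs while the reduction still reproduces the exact softmax result.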

before:
  FLASH_ATTN_EXT(hs=64,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 11924 runs -   100.01 us/run -  33.55 MFLOP/run - 335.52 GFLOPS
  FLASH_ATTN_EXT(hs=128,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 7455 runs -   139.64 us/run -  67.11 MFLOP/run - 480.59 GFLOPS
  FLASH_ATTN_EXT(hs=64,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                  5964 runs -   190.01 us/run -  67.11 MFLOP/run - 353.19 GFLOPS
  FLASH_ATTN_EXT(hs=128,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 3730 runs -   268.44 us/run - 134.22 MFLOP/run - 499.98 GFLOPS
  FLASH_ATTN_EXT(hs=64,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 2984 runs -   371.01 us/run - 134.22 MFLOP/run - 361.76 GFLOPS
  FLASH_ATTN_EXT(hs=128,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                1865 runs -   566.61 us/run - 268.44 MFLOP/run - 473.76 GFLOPS


Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 0 -n 4096,8192,16384 -fa 0,1 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  0 |        tg4096 |         72.29 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  0 |        tg8192 |         67.03 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  0 |       tg16384 |         57.92 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |        tg4096 |         66.44 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |        tg8192 |         57.09 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |       tg16384 |         43.84 ± 0.00 |

after:
  FLASH_ATTN_EXT(hs=64,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 44715 runs -    22.66 us/run -  33.55 MFLOP/run - 1.48 TFLOPS
  FLASH_ATTN_EXT(hs=128,nh=8,nr=4,kv=4096,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                32802 runs -    31.55 us/run -  67.11 MFLOP/run - 2.13 TFLOPS
  FLASH_ATTN_EXT(hs=64,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 29820 runs -    34.41 us/run -  67.11 MFLOP/run - 1.95 TFLOPS
  FLASH_ATTN_EXT(hs=128,nh=8,nr=4,kv=8192,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                20888 runs -    49.37 us/run - 134.22 MFLOP/run - 2.72 TFLOPS
  FLASH_ATTN_EXT(hs=64,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                17904 runs -    57.42 us/run - 134.22 MFLOP/run - 2.34 TFLOPS
  FLASH_ATTN_EXT(hs=128,nh=8,nr=4,kv=16384,nb=1,mask=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                6714 runs -   152.57 us/run - 268.44 MFLOP/run - 1.76 TFLOPS


Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m C:\models\meta-llama-3-8b-instruct.Q4_K_M.gguf -p 0 -n 4096,8192,16384 -fa 0,1 --repetitions 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 4070 (NVIDIA) | uma: 0 | fp16: 1 | warp size: 32 | shared memory: 49152 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  0 |        tg4096 |         72.31 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  0 |        tg8192 |         67.31 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  0 |       tg16384 |         57.95 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |        tg4096 |         74.95 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |        tg8192 |         70.85 ± 0.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |  1 |       tg16384 |         64.25 ± 0.00 |

@jeffbolznv requested a review from @0cc4m on March 28, 2025, 14:12.
@github-actions bot added the testing (everything test related), Vulkan (issues specific to the Vulkan backend), and ggml (changes relating to the ggml tensor library for machine learning) labels on Mar 28, 2025.
@0cc4m (Collaborator) left a comment:

LGTM.

Can you resolve the conflict?

@jeffbolznv (Collaborator, Author) replied:

Thanks. Conflicts were minor, I'll merge after it passes CI.

Merged commit message:

When using group query attention, we have one workgroup per KV batch and this
can be very few workgroups (e.g. just 8 in some models). Enable split_k to
spread the work across SMs. This helps a lot when the KV cache is large.

@jeffbolznv merged commit f01bd02 into ggml-org:master on Apr 2, 2025
48 checks passed