# metal : improve FA + improve MoE #12612


**Merged**: ggerganov merged 10 commits into `master` from `gg/metal-fa-diff-heads` on Mar 28, 2025
## Conversation

**@ggerganov** (Member) commented on Mar 27, 2025:

### Overview

- Add FA kernels for head_size_k != head_size_v, which enables V-cache quantization support for DeepSeek
- Improve the FA-vec kernels, especially for head size 256 (i.e. Gemma) and large contexts
- Extra TG performance improvement with a quantized KV cache at large contexts, for all models
- Improve the condition for when to use the mat-mat version of mul_mat_id, based on the number of rows per expert (a huge bottleneck for DeepSeek models) (9c2b783); see the sketch below
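For the mul_mat_id change, here is a minimal sketch of the kind of rows-per-expert heuristic the last bullet describes. The function name, signature, and threshold are illustrative assumptions, not the actual ggml-metal code:

```cpp
#include <cstdint>

// Hypothetical sketch: decide between the mat-vec and mat-mat variants of
// mul_mat_id. Keying the decision on the total number of rows alone is
// misleading for MoE models with many experts (e.g. DeepSeek), because the
// rows are spread across the experts and each expert sees only a small
// slice. Estimating the rows per expert avoids picking the mat-mat kernel
// when every expert ends up with just a handful of rows.
bool use_mul_mat_id_mm(int64_t n_rows, int64_t n_expert, int64_t n_expert_used) {
    // expected number of rows routed to each expert
    const int64_t rows_per_expert = (n_rows*n_expert_used)/n_expert;

    // the threshold is a placeholder, not the value used in the PR
    return rows_per_expert >= 32;
}
```

For DeepSeek V2 Lite (64 routed experts, 6 used per token), a 2048-token batch averages roughly 2048*6/64 = 192 rows per expert, which lines up with the large pp2048 speedup in the results below.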

### M2 Studio results

**Improved DeepSeek V2 Lite PP and TG perf**

```sh
./scripts/compare-commits.sh master gg/metal-fa-diff-heads -m ./models/deepseek-v2-lite-chat/ggml-model-q8_0.gguf -fa 1 -p 1,2048 -n 256 -t 1
```

| Model | Test | t/s master | t/s gg/metal-fa-diff-heads | Speedup |
| --- | --- | --- | --- | --- |
| deepseek2 16B Q8_0 | pp2048 | 150.94 | 968.34 | 6.42 |
| deepseek2 16B Q8_0 | tg256 | 94.87 | 97.81 | 1.03 |

**Improved DeepSeek V2 Lite large-context perf**

```sh
make -j && ./bin/llama-batched-bench -m ../models/deepseek-v2-lite-chat/ggml-model-q8_0.gguf -c 16384 -b 2048 -ub 512 -npp 512,4096,8192 -ntg 128 -npl 1 -lv 1 -fa
```

**master:**

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 3.702 | 138.30 | 1.408 | 90.88 | 5.111 | 125.23 |
| 4096 | 128 | 1 | 4224 | 27.129 | 150.98 | 1.878 | 68.17 | 29.006 | 145.62 |
| 8192 | 128 | 1 | 8320 | 54.597 | 150.04 | 2.408 | 53.16 | 57.005 | 145.95 |

**PR:**

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.841 | 608.93 | 1.368 | 93.60 | 2.208 | 289.80 |
| 4096 | 128 | 1 | 4224 | 4.297 | 953.30 | 1.764 | 72.54 | 6.061 | 696.90 |
| 8192 | 128 | 1 | 8320 | 9.151 | 895.21 | 2.217 | 57.74 | 11.368 | 731.90 |

**DeepSeek V3 IQ1_S**

```sh
make -j && ./bin/llama-bench -m unsloth_DeepSeek-V3-0324-GGUF_UD-IQ1_S_DeepSeek-V3-0324-UD-IQ1_S-00001-of-00004.gguf -fa 1
```

| model | size | params | backend | threads | fa | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| deepseek2 671B IQ1_S | 173.44 GiB | 671.03 B | Metal | 16 | 1 | pp512 | 50.89 ± 0.02 |
| deepseek2 671B IQ1_S | 173.44 GiB | 671.03 B | Metal | 16 | 1 | tg128 | 16.77 ± 0.00 |

```sh
make -j && ./bin/llama-batched-bench -m unsloth_DeepSeek-V3-0324-GGUF_UD-IQ1_S_DeepSeek-V3-0324-UD-IQ1_S-00001-of-00004.gguf -c 2048 -b 2048 -ub 512 -npp 1,1500 -ntg 128 -npl 1 -lv 1 -fa -ctk q8_0 -ctv q8_0
```

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 128 | 1 | 129 | 0.425 | 2.35 | 7.596 | 16.85 | 8.022 | 16.08 |
| 1500 | 128 | 1 | 1628 | 29.860 | 50.23 | 8.379 | 15.28 | 38.239 | 42.57 |

**DeepSeek V2 Lite with Q8_0 KV cache**

```sh
make -j && ./bin/llama-batched-bench -m ../models/deepseek-v2-lite-chat/ggml-model-q8_0.gguf -c 16384 -b 2048 -ub 512 -npp 512,4096,8192 -ntg 128 -npl 1 -lv 1 -fa -ctk q8_0 -ctv q8_0
```

**master:** not supported

**PR:**

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.679 | 754.26 | 1.407 | 90.97 | 2.086 | 306.83 |
| 4096 | 128 | 1 | 4224 | 4.527 | 904.79 | 1.695 | 75.51 | 6.222 | 678.87 |
| 8192 | 128 | 1 | 8320 | 9.983 | 820.60 | 2.044 | 62.62 | 12.027 | 691.78 |

**Improved Gemma TG perf**

```sh
./scripts/compare-commits.sh master gg/metal-fa-diff-heads -m models/gemma-2-2b/ggml-model-q4_0.gguf -m models/gemma-2-9b/ggml-model-q8_0.gguf -fa 1 -p 1,2048 -t 1
```

| Model | Test | t/s master | t/s gg/metal-fa-diff-heads | Speedup |
| --- | --- | --- | --- | --- |
| gemma2 2B Q4_0 | pp2048 | 3187.94 | 3191.13 | 1.00 |
| gemma2 2B Q4_0 | tg128 | 155.10 | 161.68 | 1.04 |
| gemma2 9B Q8_0 | pp2048 | 899.85 | 900.85 | 1.00 |
| gemma2 9B Q8_0 | tg128 | 50.72 | 51.49 | 1.02 |
| gemma3 4B Q8_0 | pp2048 | 2205.61 | 2205.91 | 1.00 |
| gemma3 4B Q8_0 | tg128 | 91.24 | 93.53 | 1.03 |
**F16 KV cache:**

```sh
make -j && ./bin/llama-batched-bench -m ../models/gemma-2-9b/ggml-model-q4_k.gguf -c 16384 -b 2048 -ub 2048 -npp 512,4096,8192 -ntg 128 -npl 1 -lv 1 -fa
```

**master:**

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 1.002 | 510.79 | 2.200 | 58.17 | 3.203 | 199.83 |
| 4096 | 128 | 1 | 4224 | 5.306 | 771.95 | 3.056 | 41.89 | 8.362 | 505.16 |
| 8192 | 128 | 1 | 8320 | 12.242 | 669.16 | 4.075 | 31.41 | 16.318 | 509.88 |

**PR:**

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.992 | 515.92 | 2.117 | 60.46 | 3.110 | 205.81 |
| 4096 | 128 | 1 | 4224 | 5.294 | 773.77 | 2.842 | 45.05 | 8.135 | 519.23 |
| 8192 | 128 | 1 | 8320 | 12.331 | 664.34 | 3.668 | 34.89 | 16.000 | 520.01 |
**Q8_0 KV cache:**

```sh
make -j && ./bin/llama-batched-bench -m ../models/gemma-2-9b/ggml-model-q4_k.gguf -c 16384 -b 2048 -ub 2048 -npp 512,4096,8192 -ntg 128 -npl 1 -lv 1 -fa -ctk q8_0 -ctv q8_0
```

**master:**

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.813 | 629.80 | 2.304 | 55.56 | 3.117 | 205.34 |
| 4096 | 128 | 1 | 4224 | 5.498 | 745.03 | 3.386 | 37.80 | 8.884 | 475.47 |
| 8192 | 128 | 1 | 8320 | 12.969 | 631.68 | 4.603 | 27.81 | 17.571 | 473.50 |

**PR:**

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 512 | 128 | 1 | 640 | 0.788 | 650.04 | 2.168 | 59.03 | 2.956 | 216.51 |
| 4096 | 128 | 1 | 4224 | 5.459 | 750.33 | 2.812 | 45.52 | 8.271 | 510.71 |
| 8192 | 128 | 1 | 8320 | 12.893 | 635.40 | 3.572 | 35.83 | 16.465 | 505.32 |

**@github-actions** bot added the **testing**, **ggml**, **Apple Metal**, **Nvidia GPU**, and **Vulkan** labels on Mar 27, 2025
**@ggerganov** force-pushed the `gg/metal-fa-diff-heads` branch from 6917e63 to 1e0f5ad on Mar 27, 2025

**@ggerganov** changed the title from "metal : add FA kernels for different K, V head sizes" to "metal : improve FA + improve MoE" on Mar 28, 2025

**@ggerganov** merged commit b4ae508 into `master` on Mar 28, 2025 (56 checks passed)

**@ggerganov** deleted the `gg/metal-fa-diff-heads` branch on Mar 28, 2025
**@PkmX** commented on Mar 29, 2025:

I'm getting this error when running https://huggingface.co/ggml-org/gemma-3-4b-it-GGUF, which git blame points to this PR:

```
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M1 Ultra
ggml_metal_init: picking default device: Apple M1 Ultra
ggml_metal_load_library: using embedded metal library
ggml_metal_load_library: error: Error Domain=MTLLibraryErrorDomain Code=3 "program_source:6106:12: error: variable length arrays are not supported in Metal
    o4_t lo[DV4/NL];
           ^
program_source:6405:18: note: in instantiation of function template specialization 'kernel_flash_attn_ext_vec<half __attribute__((ext_vector_type(4))), half __attribute__((ext_vector_type(4))), half __attribute__((ext_vector_type(4))), float, half, half __attribute__((ext_vector_type(4))), half __attribute__((ext_vector_type(4))), half __attribute__((ext_vector_type(4))), 1, &dequantize_f16_t4, half __attribute__((ext_vector_type(4))), 1, &dequantize_f16_t4, 128, 128, 128, 1, 32>' requested here
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
                 ^
" UserInfo={NSLocalizedDescription=program_source:6106:12: error: variable length arrays are not supported in Metal
    o4_t lo[DV4/NL];
           ^
program_source:6405:18: note: in instantiation of function template specialization 'kernel_flash_attn_ext_vec<half __attribute__((ext_vector_type(4))), half __attribute__((ext_vector_type(4))), half __attribute__((ext_vector_type(4))), float, half, half __attribute__((ext_vector_type(4))), half __attribute__((ext_vector_type(4))), half __attribute__((ext_vector_type(4))), 1, &dequantize_f16_t4, half __attribute__((ext_vector_type(4))), 1, &dequantize_f16_t4, 128, 128, 128, 1, 32>' requested here
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
                 ^
}
ggml_metal_init: error: metal library is nil
ggml_backend_metal_device_init: error: failed to allocate context
llama_init_from_model: failed to initialize the context: failed to initialize Metal backend
common_init_from_params: failed to create context with model 'models/gemma-3-4b-it-Q8_0.gguf'
srv    load_model: failed to load model, 'models/gemma-3-4b-it-Q8_0.gguf'
srv    operator(): operator(): cleaning up before exit...
libc++abi: terminating
main: exiting due to model loading error
```

**@ggerganov** (Member, Author):

@PkmX Can you confirm that #12659 fixes the error?

**@PkmX** commented on Mar 30, 2025:

@ggerganov No, still getting the same error even after a clean rebuild.

**@ggerganov** (Member, Author):

And are you sure that you checked out the branch of #12659?

**@PkmX** commented on Mar 30, 2025:

Yes, I applied the patch onto current master.

**@ggerganov** (Member, Author) commented on Mar 30, 2025:

@PkmX I pushed one more change there; can you check if it works now?

It's strange that it fails for you. I tested that master works correctly on M1 Pro, M2 Ultra and M4 Max.

**@PkmX** commented on Mar 30, 2025:

I did a little more debugging and found that the flash_attn_ext_vec_t typedef with NE=128 causes NL to be 32/128 = 0, and since DV4/NL is then a division by zero, it is no longer treated as a constant expression.
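As a standalone illustration of that arithmetic (a plain C++ sketch using the names from the error message; the real code is a Metal template, so this only models the failure):

```cpp
#include <cstdio>

// NL is derived from the template parameter NE by integer division.
// For small NE this yields a valid fixed-size array bound DV4/NL, but
// NE=128 drives NL to 0, so DV4/NL divides by zero and is no longer a
// constant expression, which Metal then reports as an unsupported VLA.
int main() {
    const int DV  = 128;  // V head size
    const int DV4 = DV/4; // size in half4 units

    const int nes[] = {4, 32, 128};
    for (int NE : nes) {
        const int NL = 32/NE; // integer division: NE=128 -> NL == 0
        if (NL > 0) {
            printf("NE=%3d -> NL=%d -> lo[%d] is a fixed-size array\n", NE, NL, DV4/NL);
        } else {
            printf("NE=%3d -> NL=%d -> lo[DV4/NL] divides by zero\n", NE, NL);
        }
    }
    return 0;
}
```

Any NE that divides 32 restores a constant bound, which is what the patch below does by instantiating the typedef with NE=4.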

The following patch seems to work around the problem:

```diff
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 54a92247..e82bb5dd 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3958,7 +3958,7 @@ kernel void kernel_flash_attn_ext_vec(
     half,  half4, \
            half4
 
-typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 4>;
 #if defined(GGML_METAL_USE_BF16)
```

**@ggerganov** (Member, Author):

Right, this is indeed a mistake. Thanks for pinpointing it.
