metal : handle -inf values in FA kernel #7434

Merged 1 commit into master on May 21, 2024
Conversation

ggerganov (Member) commented May 21, 2024

Fix partial offload with Flash Attention + Metal

As observed in https://github.com/ggerganov/llama.cpp/pull/7192/files#r1596612314, it seems that when simdgroup_load is used to load data that contains -INFINITY (such as the causal mask), Metal produces NaNs. In #7192 we added a poor workaround where the F32 -> F16 copy kernel explicitly checked for -INFINITY values. However, this does not solve all problems, because with partial offload the mask is cast on the CPU and later copied to the Metal backend directly.
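For reference, the #7192 workaround boils down to clamping -INFINITY while converting the F32 mask to F16. A simplified host-side sketch of the idea (plain C++, not the actual kernel_cpy_f32_f16 code; the function name and the float stand-in for half are illustrative only):

```cpp
#include <cmath>
#include <cstddef>

// Sketch of the #7192-style workaround: while casting the F32 mask to F16,
// replace -INFINITY with the most negative finite half value (-65504) so the
// FA kernel never sees an actual infinity. With partial offload the mask is
// cast on the CPU instead, so a clamp inside the Metal copy kernel is bypassed.
static void copy_mask_f32_to_f16(const float * src, float * dst /* stands in for half */, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        dst[i] = (src[i] == -INFINITY) ? -65504.0f : src[i];
    }
}
```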

In this PR, we no longer use simdgroup_load on the mask data and instead perform the operations in scalar form. This seems to fix the problem in a more general way without affecting performance.
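As a rough illustration of why the scalar path is well behaved (plain C++ on the host, not the Metal kernel, and not a claim about what simdgroup_load does internally): adding a -INFINITY mask value to a score yields -inf, and exp(-inf) is exactly 0, so the masked position simply drops out of the softmax. The hazardous pattern is -inf entering multiply-accumulate style operations, where products such as 0 * -inf are NaN and the NaN then propagates:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float score = 1.25f;     // an attention score
    const float mask  = -INFINITY; // causal-mask value for a blocked position

    // scalar path (what this PR switches to): well defined
    const float s = score + mask;                          // -inf
    std::printf("exp(score + mask) = %f\n", std::exp(s));  // prints 0 -> the position is ignored

    // the hazardous pattern: -inf inside a multiply-accumulate
    const float bad = 0.0f * mask + score;                 // 0 * -inf == NaN, and NaN propagates
    std::printf("0*mask + score    = %f\n", bad);          // prints nan
    return 0;
}
```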

Repro

```sh
# on master produces garbage
./main -m models/phi-3-4k-instruct/ggml-model-f16.gguf -e -p "Hello world. My name is" -fa -ngl 10
# on master produces NaNs in the last layer
./eval-callback -m models/phi-3-4k-instruct/ggml-model-f16.gguf -e -p "Hello" -fa -ngl 1
# on master produces garbage without the kernel_cpy_f32_f16 patch
./main -m ./models/refact-1b-fim/ggml-model-f16.gguf -p "# python fibonnaci function:" -e -n 256 -s 1 --temp 0.0 -fa
```

Bench

```sh
./scripts/compare-commits.sh master gg/metal-f16-infs -m models/phi-3-4k-instruct/ggml-model-f16.gguf -m models/gemma-2b/ggml-model-f16.gguf -m models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf -p 1,2,4,8,16,32,64,128,256,512,1024 -t 4 -fa 1
```

| Model | Model Size [GiB] | Test | t/s master | t/s gg/metal-f16-infs | Speedup |
| --- | ---: | --- | ---: | ---: | ---: |
| gemma 2B F16 | 5.64 | pp1 | 98.18 | 97.11 | 0.99 |
| gemma 2B F16 | 5.64 | pp2 | 79.51 | 78.92 | 0.99 |
| gemma 2B F16 | 5.64 | pp4 | 145.48 | 144.95 | 1.00 |
| gemma 2B F16 | 5.64 | pp8 | 289.84 | 288.99 | 1.00 |
| gemma 2B F16 | 5.64 | pp16 | 576.55 | 585.76 | 1.02 |
| gemma 2B F16 | 5.64 | pp32 | 1195.57 | 1215.90 | 1.02 |
| gemma 2B F16 | 5.64 | pp64 | 2096.09 | 2130.61 | 1.02 |
| gemma 2B F16 | 5.64 | pp128 | 3211.25 | 3255.60 | 1.01 |
| gemma 2B F16 | 5.64 | pp256 | 3967.19 | 4003.99 | 1.01 |
| gemma 2B F16 | 5.64 | pp512 | 4268.74 | 4355.20 | 1.02 |
| gemma 2B F16 | 5.64 | pp1024 | 4122.74 | 4211.68 | 1.02 |
| gemma 2B F16 | 5.64 | tg128 | 97.15 | 97.02 | 1.00 |
| gemma 2B F16 | 5.64 | pp512+tg128 | 425.91 | 426.18 | 1.00 |
| llama 7B F16 | 13.49 | pp1 | 40.33 | 40.33 | 1.00 |
| llama 7B F16 | 13.49 | pp2 | 33.33 | 33.25 | 1.00 |
| llama 7B F16 | 13.49 | pp4 | 65.57 | 65.39 | 1.00 |
| llama 7B F16 | 13.49 | pp8 | 130.88 | 130.29 | 1.00 |
| llama 7B F16 | 13.49 | pp16 | 259.91 | 259.33 | 1.00 |
| llama 7B F16 | 13.49 | pp32 | 534.74 | 533.74 | 1.00 |
| llama 7B F16 | 13.49 | pp64 | 856.23 | 858.09 | 1.00 |
| llama 7B F16 | 13.49 | pp128 | 1136.59 | 1139.71 | 1.00 |
| llama 7B F16 | 13.49 | pp256 | 1304.17 | 1302.65 | 1.00 |
| llama 7B F16 | 13.49 | pp512 | 1408.86 | 1415.65 | 1.00 |
| llama 7B F16 | 13.49 | pp1024 | 1381.00 | 1388.20 | 1.01 |
| llama 7B F16 | 13.49 | tg128 | 40.32 | 40.35 | 1.00 |
| llama 7B F16 | 13.49 | pp512+tg128 | 178.23 | 178.57 | 1.00 |
| phi3 3B F16 | 7.12 | pp1 | 58.49 | 57.85 | 0.99 |
| phi3 3B F16 | 7.12 | pp2 | 64.85 | 64.35 | 0.99 |
| phi3 3B F16 | 7.12 | pp4 | 128.49 | 127.39 | 0.99 |
| phi3 3B F16 | 7.12 | pp8 | 255.44 | 252.81 | 0.99 |
| phi3 3B F16 | 7.12 | pp16 | 503.80 | 501.60 | 1.00 |
| phi3 3B F16 | 7.12 | pp32 | 1027.71 | 1023.41 | 1.00 |
| phi3 3B F16 | 7.12 | pp64 | 1559.98 | 1559.62 | 1.00 |
| phi3 3B F16 | 7.12 | pp128 | 2022.55 | 2031.38 | 1.00 |
| phi3 3B F16 | 7.12 | pp256 | 2333.12 | 2339.13 | 1.00 |
| phi3 3B F16 | 7.12 | pp512 | 2439.53 | 2454.40 | 1.01 |
| phi3 3B F16 | 7.12 | pp1024 | 2383.48 | 2404.82 | 1.01 |
| phi3 3B F16 | 7.12 | tg128 | 59.02 | 58.92 | 1.00 |
| phi3 3B F16 | 7.12 | pp512+tg128 | 256.60 | 256.06 | 1.00 |

slaren (Member) commented May 21, 2024

The original issue may have been caused by using fast math.

(image attachment)

ggerganov (Member, Author) replied:

I tried disabling it, but it still produced NaNs.
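For context on the fast-math hypothesis: fast-math compilation (e.g. Clang's -ffast-math; Metal shaders are also compiled with fast math enabled by default) lets the compiler assume no infinities or NaNs, so arithmetic that deliberately routes -INFINITY through it is outside its contract. A small host-side illustration of that assumption (plain C++; this does not claim to reproduce what the Metal compiler actually does here):

```cpp
#include <cmath>
#include <cstdio>

// Build normally:        c++ -O2 demo.cpp             -> prints 0.000000
// Build with fast math:  c++ -O2 -ffast-math demo.cpp
// -ffast-math implies -ffinite-math-only: the compiler may assume `mask` is
// never an infinity, so the result below is no longer guaranteed to be 0.
int main() {
    volatile float mask = -INFINITY; // volatile keeps the -inf out of constant folding
    const float score = 1.0f;
    std::printf("%f\n", std::exp(score + mask));
    return 0;
}
```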

@mofosyne added the Review Complexity : High label (generally requires in-depth knowledge of LLMs or GPUs), and added and then removed the Vulkan label (issues specific to the Vulkan backend), on May 21, 2024
ggerganov merged commit 6369bf0 into master on May 21, 2024
30 checks passed
ggerganov deleted the gg/metal-f16-infs branch on May 21, 2024 at 20:03
teleprint-me pushed a commit to teleprint-me/llama.cpp that referenced this pull request May 23, 2024