
Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing #10597


Merged: 12 commits merged into master from 0cc4m/vulkan-coopmat on Dec 7, 2024

Conversation

@0cc4m (Collaborator) commented Nov 30, 2024

This PR implements basic tensor core/matrix core/xmx engine support using the vendor-neutral VK_KHR_cooperative_matrix Vulkan extension for the matrix multiplication shader. I need to spend some further time optimizing the code, but I would like to get some feedback already.

This is different from #10206 in that it should work on some non-Nvidia GPUs too.

It works on:

  • Nvidia RTX 2000 and newer

It should work on:

  • AMD RX 7000 series and newer
  • Intel ARC

HOWEVER:
It seems AMD's and Intel's drivers like to pretend they support this extension, despite either not having implemented proper hardware support (like my Intel A770 on Linux mesa), or not even having the necessary hardware at all (AMD's proprietary driver for older generations.. why?). So this probably also needs a whitelist feature to avoid regressions on hardware where support isn't actually there.
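For background, detecting the extension just means enumerating what the driver advertises, which is exactly the part that can't be trusted on its own here. A minimal host-side sketch of that query (not this PR's code; the helper name and error handling are illustrative):

```cpp
#include <vulkan/vulkan.h>
#include <cstdio>
#include <vector>

// Sketch: list the cooperative matrix configurations a device advertises.
// Assumes a Vulkan instance and physical device already exist and that the
// device exposes VK_KHR_cooperative_matrix.
static void print_coopmat_properties(VkInstance instance, VkPhysicalDevice phys_dev) {
    // Extension entry points have to be fetched at runtime.
    auto get_props = (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)
        vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
    if (!get_props) {
        std::printf("VK_KHR_cooperative_matrix entry point not available\n");
        return;
    }

    uint32_t count = 0;
    get_props(phys_dev, &count, nullptr);
    std::vector<VkCooperativeMatrixPropertiesKHR> props(count);
    for (auto & p : props) {
        p.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
        p.pNext = nullptr;
    }
    get_props(phys_dev, &count, props.data());

    // Each entry is one supported MxNxK shape with its component types and scope.
    for (const auto & p : props) {
        std::printf("M=%u N=%u K=%u A=%d B=%d C=%d Result=%d scope=%d\n",
                    p.MSize, p.NSize, p.KSize,
                    (int)p.AType, (int)p.BType, (int)p.CType, (int)p.ResultType, (int)p.scope);
    }
}
```

The catch described above is that this query can succeed and return entries even when the driver only emulates the extension on the regular shader cores, hence the need for an additional whitelist.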

@github-actions bot added the Vulkan (Issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on Nov 30, 2024
@0cc4m (Collaborator, Author) commented Nov 30, 2024

Some preliminary benchmarks, but there's probably some more optimization that can be done.

RTX 3090

| model | size | params | backend | ngl | test | t/s master | t/s PR | t/s CUDA |
| --- | ---: | ---: | --- | --: | ---: | ---: | ---: | ---: |
| llama 7B Q4_0 | 3.83 GiB | 7.24 B | Vulkan | 99 | pp512 | 825.42 ± 1.14 | 1933.18 ± 9.55 | 5059.05 ± 11.44 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | pp512 | 680.42 ± 0.73 | 1381.26 ± 7.01 | 4692.44 ± 4.78 |

Intel just for fun:

Intel A770 (Linux Mesa)

| model | size | params | backend | ngl | test | t/s master | t/s PR |
| --- | ---: | ---: | --- | --: | ---: | ---: | ---: |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 99 | pp512 | 150.46 ± 0.05 | 25.48 ± 0.12 |

@qnixsynapse (Collaborator) commented:

> It seems AMD's and Intel's drivers like to pretend they support this extension, despite either not having implemented proper hardware support (like my Intel A770 on Linux mesa),

This should be reported to Freedesktop Gitlab IMO.

Also, I do test the Vulkan backend on my Intel Arc, and I found out that the Vulkan backend only uses the shader engines instead of the compute engines. The latter are faster. Is this intentional?

@0cc4m (Collaborator, Author) commented Nov 30, 2024

> > It seems AMD's and Intel's drivers like to pretend they support this extension, despite either not having implemented proper hardware support (like my Intel A770 on Linux mesa),
>
> This should be reported to Freedesktop Gitlab IMO.

They know.

> Also, I do test the Vulkan backend on my Intel Arc, and I found out that the Vulkan backend only uses the shader engines instead of the compute engines. The latter are faster. Is this intentional?

Assuming you mean shader cores and XMX engines, that's exactly what this PR is about. By default Vulkan can run compute shaders only on regular shader cores, it needs extensions like VK_KHR_cooperative_matrix to use the specialized hardware of tensor cores, AMD matrix cores or Intel XMX engines.

@jeffbolznv (Collaborator) left a review:

I didn't review all the addressing math in detail.

@jeffbolznv (Collaborator) commented:

Very cool to see this! I tested the changes on RTX 4070, and test-backend-ops is passing, but I don't see much perf gain as-is (like +/- 20% testing a Q4_K model).

I see a couple perf opportunities. First, the branches around the loads for quants like Q4_K can cause the shader to stall inside the if/else waiting for the data. I had to "flatten" the control flow (e.g. https://github.com/ggerganov/llama.cpp/pull/10206/files#diff-508f8330faff804d4067ccb566c1d3eb27a7021a03df4f914c00c977d114514fR181) for the coopmat2 dequant functions, I suspect something similar is needed here.

Second, add [[unroll]] around the store loops. The dynamic indexing of the sums array is problematic and goes away with unrolling.

I tried these locally/hackily and then I get a much better speedup (like 2.5x).
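To make the flattening idea concrete, here is the same pattern in plain C++ rather than GLSL; the nibble layout, scale handling, and function names below are made up for illustration and are not the actual Q4_K dequantization:

```cpp
#include <cstdint>

// Branchy version: the load sits inside the if/else, so a shader compiler
// tends to issue it late and wait on it inside the branch.
float dequant_branchy(const uint8_t * q, const float * scales, int i) {
    if (i % 2 == 0) {
        uint8_t v = q[i / 2] & 0x0F;   // load waited on inside this branch
        return scales[0] * float(v);
    } else {
        uint8_t v = q[i / 2] >> 4;     // ...or inside this one
        return scales[1] * float(v);
    }
}

// Flattened version: do the load unconditionally up front, then select.
// The compiler can start the load early and overlap it with other work.
float dequant_flat(const uint8_t * q, const float * scales, int i) {
    uint8_t byte  = q[i / 2];          // single unconditional load
    bool    low   = (i % 2) == 0;
    uint8_t v     = low ? uint8_t(byte & 0x0F) : uint8_t(byte >> 4);
    float   scale = low ? scales[0] : scales[1];
    return scale * float(v);
}
```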

@0cc4m (Collaborator, Author) commented Nov 30, 2024

> I see a couple perf opportunities. First, the branches around the loads for quants like Q4_K can cause the shader to stall inside the if/else waiting for the data. I had to "flatten" the control flow (e.g. https://github.com/ggerganov/llama.cpp/pull/10206/files#diff-508f8330faff804d4067ccb566c1d3eb27a7021a03df4f914c00c977d114514fR181) for the coopmat2 dequant functions, I suspect something similar is needed here.

Interesting. So the GPU has to wait for the load from global memory, but this is made worse by the load being inside a branch?

> Second, add [[unroll]] around the store loops. The dynamic indexing of the sums array is problematic and goes away with unrolling.
>
> I tried these locally/hackily and then I get a much better speedup (like 2.5x).

I implemented those changes. I hope you don't mind that I used your loading code. Maybe in the future we can just use your dequant_funcs file for mul_mm too.

I don't see a drastic improvement, but it's 7-8% faster for me on 3090 on Q4_K_S, so that's good. Let me know if you see anything else.

I'm also curious if you learned these GLSL optimization techniques from experience, or if you can recommend some resources I can use to learn. This is my first Vulkan project and the first time going deep into GPU programming, and it's been difficult to find information I can apply.

@jeffbolznv (Collaborator) commented:

> Interesting. So the GPU has to wait for the load from global memory, but this is made worse by the load being inside a branch?

Yeah, it's challenging for a compiler to track memory accesses across basic blocks, so doing loads inside a branch can cause the value to be waited on inside the branch. And that prevents doing other useful work while the load is pending.

> I'm also curious if you learned these GLSL optimization techniques from experience, or if you can recommend some resources I can use to learn. This is my first Vulkan project and the first time going deep into GPU programming, and it's been difficult to find information I can apply.

I think most performance advice is likely to come from vendor specific documentation/tooling, even though a lot of the information and optimization techniques can apply across different hardware. I'd suggest looking at nsight documentation and cuda forums as they'll likely have a lot more low-level information about NVIDIA hardware than you would find in Vulkan-centric documentation. For example, https://developer.nvidia.com/blog/identifying-shader-limiters-with-the-shader-profiler-in-nvidia-nsight-graphics/ talks about using nsight and also touches on this issue of dynamically indexed arrays.

@jeffbolznv (Collaborator) commented:

I got a device lost error while trying to collect perf results, running Meta-Llama-3-8B-Instruct-Q4_K_S.gguf with pp512. But these two other models run:

master
| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 1000 | pp512 | 831.12 ± 4.03 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Vulkan | 1000 | pp512 | 1135.54 ± 1.70 |

PR
| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
| llama 7B Q4_0 | 3.56 GiB | 6.74 B | Vulkan | 1000 | pp512 | 1904.89 ± 8.26 |
| phi3 3B Q4_K - Medium | 2.23 GiB | 3.82 B | Vulkan | 1000 | pp512 | 2213.51 ± 25.68 |

@characharm (Contributor) commented:

Windows OS

This PR:
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Intel(R) Arc(TM) A770 Graphics (Intel Corporation) | uma: 0 | fp16: 1 | warp size: 32 | matrix cores: 1

| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
ggml_vulkan: Compiling shaders..............................Done!
| qwen2 14B Q5_K - Medium | 9.78 GiB | 14.77 B | Vulkan | 99 | pp512 | 20.84 ± 0.03 |
| qwen2 14B Q5_K - Medium | 9.78 GiB | 14.77 B | Vulkan | 99 | tg128 | 16.01 ± 0.06 |

build: 5660976 (4231)

master
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Intel(R) Arc(TM) A770 Graphics (Intel Corporation) | uma: 0 | fp16: 1 | warp size: 32

| model | size | params | backend | ngl | test | t/s |
| --- | ---: | ---: | --- | --: | ---: | ---: |
ggml_vulkan: Compiling shaders..............................Done!
| qwen2 ?B Q5_K - Medium | 9.78 GiB | 14.77 B | Vulkan,RPC | 99 | pp512 | 67.85 ± 0.52 |
| qwen2 ?B Q5_K - Medium | 9.78 GiB | 14.77 B | Vulkan,RPC | 99 | tg128 | 15.52 ± 0.06 |

build: 25669aa (4179)

@easyfab commented Dec 1, 2024

Sadly, I have the same regression with A770.

By the way, which is the best Mesa version on Linux (Docker)? I tried the default (Ubuntu 24.04), ppa:kisak/kisak-mesa, and oibaf/graphics-drivers.

I was surprised that the default version is faster than the newer versions (25 t/s vs 20 t/s). Is this expected?

Default Mesa 24.0.9-0ubuntu0.2 (LLVM 17.0.6)

MESA-INTEL: warning: cannot initialize blitter engine
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Intel(R) Arc(tm) A770 Graphics (DG2) (Intel open-source Mesa driver) | uma: 0 | fp16: 1 | warp size: 32
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
ggml_vulkan: Compiling shaders..............................Done!
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |         pp512 |        154.20 ± 0.02 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |         tg128 |         24.99 ± 0.02 |

Mesa 24.3.0 - kisak-mesa PPA (LLVM 17.0.6)

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = Intel(R) Arc(tm) A770 Graphics (DG2) (Intel open-source Mesa driver) | uma: 0 | fp16: 1 | warp size: 32
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
ggml_vulkan: Compiling shaders..............................Done!
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |         pp512 |        151.10 ± 0.05 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | Vulkan     |  99 |         tg128 |         19.96 ± 0.01 |

@0cc4m (Collaborator, Author) commented Dec 1, 2024

@jeffbolznv I reworked your shared memory shader selection code because I started getting segfaults on Intel that might be related to the growing number of shaders. Now only the shaders that actually get used are compiled, and since Intel only uses the smallest tile size, this reduces the number greatly. This did fix the segfault.

The logic is very different now, maybe you can look over it and check it still makes sense.
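The gist of that rework, as a simplified sketch (the mul_mat_s/m/l flags mirror the fields visible in the review snippets below; everything else is illustrative rather than the actual ggml-vulkan code):

```cpp
// Sketch: decide up front which matmul tile sizes a device should use, then
// only build pipelines for those variants.
struct mul_mat_caps {
    bool mul_mat_s = true;   // small tile
    bool mul_mat_m = true;   // medium tile
    bool mul_mat_l = true;   // large tile
};

static void create_mul_mat_pipelines(mul_mat_caps & caps, bool is_intel) {
    // Disable tile sizes early if the device can't run them at a useful speed
    // (per the thread, Intel currently only handles the small tile well).
    if (is_intel) {
        caps.mul_mat_m = false;
        caps.mul_mat_l = false;
    }

    // Only the surviving variants get compiled, which keeps the pipeline count
    // (and the work handed to the driver's shader compiler) down.
    if (caps.mul_mat_s) { /* compile small-tile pipeline (coopmat or scalar) */ }
    if (caps.mul_mat_m) { /* compile medium-tile pipeline */ }
    if (caps.mul_mat_l) { /* compile large-tile pipeline */ }
}
```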

@@ -1745,6 +1926,42 @@ static vk_device ggml_vk_get_device(size_t idx) {
ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);

// Shaders
// Disable matmul tile sizes early if not supported
Collaborator:

I guess this is replacing the old vendor "guess" functions, but what does "if not supported" mean? Is this still needed or does the shared memory calculation replace it?

Collaborator Author:

"Not performant" would be more correct here. The large tiles only run well on Nvidia. The medium tiles only run well on AMD and Nvidia. Intel can only run the small tiles at any decent speed. I think this is related to registers spilling to global memory or overloading caches.

return aligned ? mmp->a_s : mmp->s;
}
if (m <= 64 || n <= 64) {
if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64 || ctx->device->coopmat_support)) || !ctx->device->mul_mat_l) {
Collaborator:

I don't think coopmat_support should force medium, at least not on NVIDIA.

Collaborator Author:

When it first started working I ran the three tile sizes next to each other and the large tile was very slow. I'll test it again.

Collaborator Author:

It's not as slow as I remember, but performance drops significantly on RTX 3090 if I allow the large tile.

Collaborator Author:

I narrowed it down a little further: Performance drops significantly with the large tiles when working with quants. With float16 or float32 the large tiles work well with coopmat.

Collaborator:

Looks like adding the remaining [[unroll]]s fixes the perf for large tiles.

Collaborator Author:

You... are right. So basically just unroll every loop you can get your hands on?

I guess it's dynamic access to shared memory this time. Why did that affect the large tile so significantly?

Collaborator:

I think it's still the sums array, but maybe it was getting unrolled automatically for the medium size but not for the large. The compilers aren't as aggressive about this as they probably should be, so anytime you have a dynamically indexed array where unrolling would make the indices constant, it's usually worth doing (and that's why #pragma unroll shows up literally hundreds of times in the cuda kernels).
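To illustrate in plain C++ what the unrolling buys (the real case is the GLSL sums accumulator; the array size and arithmetic here are made up):

```cpp
// Dynamically indexed accumulator: on a GPU, sums[idx] with a runtime index
// often forces the array out of registers (spilling to local/shared memory).
void accumulate_dynamic(float * sums /* size 4 */, const float * vals, int n) {
    for (int i = 0; i < n; ++i) {
        sums[i % 4] += vals[i];          // runtime index into the array
    }
}

// Unrolled by the accumulator size: every index is now a compile-time constant,
// so each element can live in its own register. This is what [[unroll]] (GLSL)
// or #pragma unroll (CUDA) buys when the trip count is known.
void accumulate_unrolled(float * sums /* size 4 */, const float * vals, int n) {
    int i = 0;
    for (; i + 4 <= n; i += 4) {
        sums[0] += vals[i + 0];
        sums[1] += vals[i + 1];
        sums[2] += vals[i + 2];
        sums[3] += vals[i + 3];
    }
    for (; i < n; ++i) {                 // tail
        sums[i % 4] += vals[i];
    }
}
```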

@jeffbolznv (Collaborator) commented:

For the Intel perf issues, my best guesses are to add more [[unroll]] on the coopmatmuladd loops, and maybe use subgroup size control to force 32 invocations per subgroup.
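For reference, the subgroup size is forced from the host at pipeline creation time via VK_EXT_subgroup_size_control (core in Vulkan 1.3); a rough sketch of the mechanism, not the backend's actual code:

```cpp
#include <vulkan/vulkan.h>

// Sketch: request a fixed subgroup size (e.g. 32) for a compute shader stage.
// Requires VK_EXT_subgroup_size_control (or Vulkan 1.3) and a device whose
// supported min/max subgroup size range includes the requested value.
VkPipelineShaderStageCreateInfo make_compute_stage(VkShaderModule module) {
    static VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT subgroup_size = {};
    subgroup_size.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT;
    subgroup_size.requiredSubgroupSize = 32;   // force 32 invocations per subgroup

    VkPipelineShaderStageCreateInfo stage = {};
    stage.sType  = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    stage.pNext  = &subgroup_size;             // chain in the required-size struct
    stage.stage  = VK_SHADER_STAGE_COMPUTE_BIT;
    stage.module = module;
    stage.pName  = "main";
    // Optionally also set VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT
    // in stage.flags when the workgroup size is a multiple of the subgroup size.
    return stage;
}
```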

@qnixsynapse (Collaborator) commented Dec 2, 2024

> > > It seems AMD's and Intel's drivers like to pretend they support this extension, despite either not having implemented proper hardware support (like my Intel A770 on Linux mesa),
> >
> > This should be reported to Freedesktop Gitlab IMO.
>
> They know.
>
> > Also, I do test the Vulkan backend on my Intel Arc, and I found out that the Vulkan backend only uses the shader engines instead of the compute engines. The latter are faster. Is this intentional?
>
> Assuming you mean shader cores and XMX engines, that's exactly what this PR is about. By default Vulkan can run compute shaders only on regular shader cores, it needs extensions like VK_KHR_cooperative_matrix to use the specialized hardware of tensor cores, AMD matrix cores or Intel XMX engines.

I see.

Well, the backend is failing on a lot of tests:

 MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3]): [MUL_MAT] NaN at index 25 (Vulkan0=-nan CPU=-0.489127) FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2]): [MUL_MAT] NaN at index 98 (Vulkan0=nan CPU=-2.006299) FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1]): [MUL_MAT] NaN at index 43 (Vulkan0=nan CPU=-3.059767) FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3]): [MUL_MAT] NaN at index 17 (Vulkan0=-nan CPU=15.675148) FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2]): [MUL_MAT] NaN at index 42 (Vulkan0=nan CPU=-5.011763) FAIL
  MUL_MAT(type_a=f32,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1]): [MUL_MAT] NaN at index 27 (Vulkan0=-nan CPU=1.719916) FAIL

MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 19 (Vulkan0=nan CPU=-2.400129) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 10 (Vulkan0=-nan CPU=-0.972326) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 353 (Vulkan0=-nan CPU=-3.618169) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 58 (Vulkan0=-nan CPU=2.903507) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 73 (Vulkan0=-nan CPU=-0.198813) FAIL

  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3]): [MUL_MAT] NaN at index 9 (Vulkan0=nan CPU=10.375449) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2]): [MUL_MAT] NaN at index 107 (Vulkan0=nan CPU=-6.503006) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1]): [MUL_MAT] NaN at index 203 (Vulkan0=-nan CPU=-2.009472) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3]): [MUL_MAT] NaN at index 0 (Vulkan0=nan CPU=3.655150) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2]): [MUL_MAT] NaN at index 81 (Vulkan0=-nan CPU=-0.457537) FAIL
  MUL_MAT(type_a=f16,type_b=f32,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1]): [MUL_MAT] NaN at index 128 (Vulkan0=-nan CPU=7.094562) FAIL

 MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NMSE = 1.000019055 > 0.000500000 FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[10,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 161 (Vulkan0=-nan CPU=-0.011188) FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[10,1],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 51 (Vulkan0=nan CPU=-4.520026) FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[10,10],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 8 (Vulkan0=-nan CPU=9.592922) FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[10,10],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 106 (Vulkan0=nan CPU=-2.529743) FAIL

 MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3]): [MUL_MAT] NaN at index 224 (Vulkan0=nan CPU=-4.796510) FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2]): [MUL_MAT] NaN at index 57 (Vulkan0=nan CPU=-0.507002) FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=8,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1]): [MUL_MAT] NaN at index 90 (Vulkan0=-nan CPU=-3.574138) FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,2,1,3]): [MUL_MAT] NaN at index 17 (Vulkan0=nan CPU=3.233534) FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,1,3,2]): [MUL_MAT] NaN at index 24 (Vulkan0=-nan CPU=2.170926) FAIL
  MUL_MAT(type_a=f16,type_b=f16,m=16,n=16,k=256,bs=[2,3],nr=[1,1],per=[0,3,2,1]): [MUL_MAT] NaN at index 305 (Vulkan0=-nan CPU=3.681449) FAIL

 MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 10 (Vulkan0=nan CPU=12.519882) FAIL
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 2315 (Vulkan0=-nan CPU=11.811843) FAIL
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 4867 (Vulkan0=-nan CPU=-1.555186) FAIL
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 16179 (Vulkan0=-nan CPU=-1.033244) FAIL
  MUL_MAT(type_a=q8_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 3040 (Vulkan0=-nan CPU=2.866659) FAIL

 MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 67 (Vulkan0=nan CPU=7.281193) FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 825 (Vulkan0=nan CPU=-6.018594) FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 1920 (Vulkan0=-nan CPU=6.858445) FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 9058 (Vulkan0=nan CPU=14.321257) FAIL
  MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 792 (Vulkan0=-nan CPU=-3.903759) FAIL

  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 0 (Vulkan0=-nan CPU=6.785239) FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 2433 (Vulkan0=nan CPU=-2.455450) FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,1],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 3928 (Vulkan0=nan CPU=-0.206156) FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[1,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 2928 (Vulkan0=-nan CPU=-1.648083) FAIL
  MUL_MAT(type_a=q4_1,type_b=f32,m=16,n=16,k=256,bs=[10,10],nr=[2,1],per=[0,1,2,3]): [MUL_MAT] NaN at index 2418 (Vulkan0=nan CPU=-2.766699) FAIL

 MUL_MAT(type_a=q4_K,type_b=f32,m=16,n=16,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3]): terminate called after throwing an instance of 'vk::DeviceLostError'
  what():  vk::Device::waitForFences: ErrorDeviceLost
Aborted (core dumped)

Still uses the 3D shader engines for computation with this PR.
GPU: Intel Arc A750

@0cc4m (Collaborator, Author) commented Dec 2, 2024

> For the Intel perf issues, my best guesses are to add more [[unroll]] on the coopmatmuladd loops, and maybe use subgroup size control to force 32 invocations per subgroup.

At least on Linux, Mesa's ANV Intel driver currently just emulates the coopmat extension on the shader cores in an inefficient way, until proper support is added. I don't know what the Windows driver is doing, but I assume something similar.

Before this PR is ready to merge, I'll have to implement a driver/device whitelist to make sure the coopmat codepath only happens (by default) on devices that actually support it properly.

> Well, the backend is failing on a lot of tests:
> [...]
> Still uses the 3D shader engines for computation with this PR. GPU: Intel Arc A750

Thank you for reporting the failing tests on Windows.

It will always use the shader engines as well, the XMX engines are specialized and can only do some specific operations. But it seems the Intel Windows driver also hasn't implemented coopmat support properly. I assume if you can see whether an application uses shader or compute engines, and this PR still doesn't use the compute engines, then Intel on Windows is also just emulating support.

@qnixsynapse (Collaborator) commented:

> It will always use the shader engines as well, the XMX engines are specialized and can only do some specific operations. But it seems the Intel Windows driver also hasn't implemented coopmat support properly. I assume if you can see whether an application uses shader or compute engines, and this PR still doesn't use the compute engines, then Intel on Windows is also just emulating support.

I'm sorry for not mentioning this earlier. It's the Mesa ANV driver on Arch Linux.

@0cc4m (Collaborator, Author) commented Dec 2, 2024

> I'm sorry for not mentioning this earlier. It's the Mesa ANV driver on Arch Linux.

Oh, interesting. So you're monitoring through intel_gpu_top. I assume it just shows this backend under Render/3D because it's Vulkan, not OpenCL/LevelZero. There's nothing I can do about that.

However, on my A770 under Linux Mesa 24.0.9-0ubuntu0.2, the tests do pass. Performance is low, but the results are correct. Which version are you using?

@qnixsynapse (Collaborator) commented:

I am using Mesa 24.2.7 and the latest mainline Linux kernel. Everything is up to date.

> I assume it just shows this backend under Render/3D because it's Vulkan, not OpenCL/LevelZero. There's nothing I can do about that.

Yeah. I think I need to have a talk upstream. That's the reason why I am testing this PR.

@characharm (Contributor) commented:

(screenshot: Windows Task Manager GPU utilization)
Vulkan has been loading the compute units for quite some time now, but unlike with SYCL, Windows Task Manager also reports some load on 3D.

@jeffbolznv (Collaborator) commented:

> I got a device lost error while trying to collect perf results, running Meta-Llama-3-8B-Instruct-Q4_K_S.gguf with pp512.

I no longer see this with the latest code. test-backend-ops passes, I touch tested a few models and perf is quite good and they seem to be working correctly. So things seem to be in pretty good shape.

What else do you think remains to make this no longer "Draft"?

@0cc4m (Collaborator, Author) commented Dec 3, 2024

> > I got a device lost error while trying to collect perf results, running Meta-Llama-3-8B-Instruct-Q4_K_S.gguf with pp512.
>
> I no longer see this with the latest code. test-backend-ops passes, I touch tested a few models and perf is quite good and they seem to be working correctly. So things seem to be in pretty good shape.
>
> What else do you think remains to make this no longer "Draft"?

In the current state I'm pretty sure this PR will tank the performance of most AMD GPUs on Windows, since that driver erroneously reports support for coopmats on every GPU. I'm gonna have to figure out a way to tell which devices have proper support (RX 7000 series) and which are just pretending.

@0cc4m (Collaborator, Author) commented Dec 4, 2024

It's even worse than I expected. I tested it with an RX 6800 XT on Windows and it advertises support, but outright crashes when compiling a coopmat shader. On RX 7900 XTX it runs, but gives incorrect results, so I'll disable coopmats for AMD on Windows for now. Linux Mesa should work, I hope someone can test it.
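A sketch of the kind of check this implies, using VkPhysicalDeviceDriverProperties to identify the driver; the policy below just mirrors what has been reported in this thread so far and is illustrative, not the exact logic that was merged:

```cpp
#include <vulkan/vulkan.h>

// Sketch: decide whether a driver's VK_KHR_cooperative_matrix support should
// be trusted, based on the driver ID. Policy is illustrative only.
static bool coopmat_trusted(VkPhysicalDevice phys_dev) {
    VkPhysicalDeviceDriverProperties driver_props = {};
    driver_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES;

    VkPhysicalDeviceProperties2 props2 = {};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    props2.pNext = &driver_props;
    vkGetPhysicalDeviceProperties2(phys_dev, &props2);

    switch (driver_props.driverID) {
        case VK_DRIVER_ID_NVIDIA_PROPRIETARY:        // Turing+ works on Windows and Linux
        case VK_DRIVER_ID_MESA_RADV:                 // RX 7000 expected to work on Linux Mesa
            return true;
        case VK_DRIVER_ID_AMD_PROPRIETARY:           // crashes / wrong results reported above
        case VK_DRIVER_ID_INTEL_OPEN_SOURCE_MESA:    // coopmat currently emulated, slower than master
        case VK_DRIVER_ID_INTEL_PROPRIETARY_WINDOWS: // same symptom reported on Windows
        default:
            return false;
    }
}
```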

@0cc4m (Collaborator, Author) commented Dec 4, 2024

It has now been tested on Linux Mesa with an AMD RX 7900 XTX and works fine. So the result is that the feature will be available for Nvidia (Turing+) on Windows and Linux, and for the AMD RX 7000 series on Linux. The 7000 series on Windows and Intel Arc on Linux and Windows will need proper driver-side implementations before they can utilize coopmats.

@0cc4m 0cc4m marked this pull request as ready for review December 4, 2024 16:05
@jeffbolznv (Collaborator) commented:

I ran out of time today, I'll review this tomorrow.

@0cc4m 0cc4m force-pushed the 0cc4m/vulkan-coopmat branch from b654968 to f54afb4 Compare December 5, 2024 20:47
@0cc4m (Collaborator, Author) commented Dec 5, 2024

I think this is basically ready. I still want to check if converting the B matrices to fp16 would improve performance of the coopmat shader, but that's not urgent.

@jeffbolznv Please check if I messed up any coopmat2 code while merging. I checked that it still works, at least.

@jeffbolznv (Collaborator) commented:

I did some testing, and everything looks good with #10597 (comment) fixed. I'll +1 it now, please go ahead and merge after fixing that.

@0cc4m (Collaborator, Author) commented Dec 6, 2024

I forgot to apply the coopmat2 shader selection logic to the mul_mat_id selection function. I'll do that soon, afterwards I'll merge.

@mmerecki commented Dec 6, 2024

I ran the backend test (test-backend-ops.exe -b Vulkan0) on an Arc A750 on Windows, using driver version 32.0.101.6314.
All tests passed (test-backend-ops.log).
Is test-backend-ops the test that was failing on Intel Windows?

@0cc4m (Collaborator, Author) commented Dec 6, 2024

> I ran the backend test (test-backend-ops.exe -b Vulkan0) on an Arc A750 on Windows, using driver version 32.0.101.6314. All tests passed (test-backend-ops.log). Is test-backend-ops the test that was failing on Intel Windows?

The feature was disabled on Intel for now, so the tests pass and performance is the same as on master.

@mmerecki commented Dec 6, 2024

> The feature was disabled on Intel for now, so the tests pass and performance is the same as on master.

Yes, I noticed that and re-enabled it for the test. I might not have used the correct binary, though.
I will check again to see whether these issues are still present with newer Windows drivers. Thank you for this PR.

@airlied (Contributor) commented Dec 6, 2024

So I think in theory the A770 should have the hardware instructions hooked up in the Mesa driver, they emulate it on hardware that doesn't have them, and Intel claimed that might be more optimal than doing it manually, but I doubt this is always true.

https://gitlab.freedesktop.org/mesa/mesa/-/blob/main/src/intel/compiler/brw_compiler.c?ref_type=heads#L104 is the dpas instruction lowering for hw that doesn't have it.

@0cc4m (Collaborator, Author) commented Dec 7, 2024

> So I think in theory the A770 should have the hardware instructions hooked up in the Mesa driver, they emulate it on hardware that doesn't have them, and Intel claimed that might be more optimal than doing it manually, but I doubt this is always true.
>
> https://gitlab.freedesktop.org/mesa/mesa/-/blob/main/src/intel/compiler/brw_compiler.c?ref_type=heads#L104 is the dpas instruction lowering for hw that doesn't have it.

Here's the issue tracking proper VK_KHR_cooperative_matrix support: https://gitlab.freedesktop.org/mesa/mesa/-/issues/9250

@oscarbg (Contributor) commented Dec 7, 2024

Hi, sorry for asking here, but once this is merged I plan to run benchmarks on NV with the Vulkan dev driver supporting cooperative matrix 2, and I want to select between the coop matrix 1, coop matrix 2, and no-coop-matrix paths. Is there any environment variable or command line argument to force each of these three "backends"?
Thanks.

@0cc4m (Collaborator, Author) commented Dec 7, 2024

> Hi, sorry for asking here, but once this is merged I plan to run benchmarks on NV with the Vulkan dev driver supporting cooperative matrix 2, and I want to select between the coop matrix 1, coop matrix 2, and no-coop-matrix paths. Is there any environment variable or command line argument to force each of these three "backends"? Thanks.

Yes, we have provided environment variables to disable coopmat and coopmat2 support if needed. They are GGML_VK_DISABLE_COOPMAT and GGML_VK_DISABLE_COOPMAT2.

I have used them to generate these graphs:
(graphs omitted: coopmat_rtx3090_pp and coopmat_rtx3090_tg)
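For anyone scripting such comparisons, the gating amounts to checking those variables before enabling each path; a minimal sketch of the pattern (not the backend's exact code, and the driver capability flags are stand-ins):

```cpp
#include <cstdlib>

// Sketch: honour GGML_VK_DISABLE_COOPMAT / GGML_VK_DISABLE_COOPMAT2 when
// deciding which matmul path to use. The driver_* flags stand in for the
// device/driver capability checks; the control flow is the interesting part.
struct matmul_path {
    bool use_coopmat2 = false;
    bool use_coopmat  = false;   // falls back to the scalar shaders if both are false
};

matmul_path select_matmul_path(bool driver_coopmat, bool driver_coopmat2) {
    matmul_path p;
    const bool disable_coopmat  = std::getenv("GGML_VK_DISABLE_COOPMAT")  != nullptr;
    const bool disable_coopmat2 = std::getenv("GGML_VK_DISABLE_COOPMAT2") != nullptr;

    p.use_coopmat2 = driver_coopmat2 && !disable_coopmat2;
    p.use_coopmat  = !p.use_coopmat2 && driver_coopmat && !disable_coopmat;
    return p;
}
```

In practice that means running the same benchmark (e.g. llama-bench) three times: unmodified for the coopmat2 path, with GGML_VK_DISABLE_COOPMAT2=1 for the coopmat1 path, and with both variables set for the plain shader path.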

@0cc4m 0cc4m force-pushed the 0cc4m/vulkan-coopmat branch from 297f5ca to eeaf0b9 Compare December 7, 2024 07:41
@airlied (Contributor) commented Dec 7, 2024

> > So I think in theory the A770 should have the hardware instructions hooked up in the Mesa driver, they emulate it on hardware that doesn't have them, and Intel claimed that might be more optimal than doing it manually, but I doubt this is always true.
> > https://gitlab.freedesktop.org/mesa/mesa/-/blob/main/src/intel/compiler/brw_compiler.c?ref_type=heads#L104 is the dpas instruction lowering for hw that doesn't have it.
>
> Here's the issue tracking proper VK_KHR_cooperative_matrix support: https://gitlab.freedesktop.org/mesa/mesa/-/issues/9250

Pretty sure nearly all that work was completed and upstreamed nearly a year ago. I played around with coop matrix on llama.cpp and the A770 a while back, ran into some of this, and talked to a few Intel folks on IRC, but I ended up getting distracted.

@0cc4m 0cc4m merged commit 3df784b into master Dec 7, 2024
44 checks passed
@0cc4m 0cc4m deleted the 0cc4m/vulkan-coopmat branch December 7, 2024 09:24
@oscarbg (Contributor) commented Dec 7, 2024

> GGML_VK_DISABLE_COOPMAT

many thanks!!

@jeffbolznv (Collaborator) commented:

@oscarbg we'll release an updated developer driver soon (next week, hopefully) with some additional optimizations for coopmat2, so it might make sense to wait for that.

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Dec 20, 2024
Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (ggml-org#10597)

* Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix matrix multiplication shader

* Improve performance with better q4_k and q5_k dequant and store unrolling

* Add Vulkan MUL_MAT and MUL_MAT_ID accumulator precision selection

* Rework mulmat shader selection and compilation logic, avoid compiling shaders that won't get used by device

* Vulkan: Implement accumulator switch for specific mul mat mat shaders

* Vulkan: Unroll more loops for more mul mat mat performance

* Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic

* Disable coopmat support on AMD proprietary driver

* Remove redundant checks

* Add environment variable GGML_VK_DISABLE_COOPMAT to disable VK_KHR_cooperative_matrix support

* Fix rebase typo

* Fix coopmat2 MUL_MAT_ID pipeline selection
tinglou pushed a commit to tinglou/llama.cpp that referenced this pull request Feb 13, 2025
Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (ggml-org#10597)
