Avoid fp32->fp16->fp32 conversion on cdna in ggml_cuda_op_mul_mat_cublas #11356


Merged: 1 commit merged into ggml-org:master on Jan 24, 2025

Conversation

@IMbackK (Collaborator) commented on Jan 22, 2025

This further improves on #10498 by removing the fp32->fp16->fp32 conversion on CDNA in ggml_cuda_op_mul_mat_cublas. Unlike what is stated in #10498, this does improve performance; the issue fixed by #11244 was simply hiding the change. The issue fixed by #11244 was also hiding a pessimization in #10498, which this PR also reverts.
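For context, a rough sketch (not the actual llama.cpp code) of the difference between the two cuBLAS/hipBLAS-style paths discussed here; the function names, operand layout, and leading dimensions are assumptions made purely for the illustration:

```cpp
// Sketch only: on CDNA there is no benefit from the FP16 path for this case, so
// converting FP32 activations to FP16, running an FP16 GEMM, and converting the
// result back to FP32 just adds extra kernels and extra memory traffic.
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Direct FP32 GEMM: the path this PR keeps on CDNA (no fp32->fp16->fp32 round trip).
static cublasStatus_t gemm_fp32_direct(cublasHandle_t handle, int m, int n, int k,
                                       const float * A, const float * B, float * C) {
    const float alpha = 1.0f, beta = 0.0f;
    // C(m x n) = A^T(m x k) * B(k x n), column-major as cuBLAS expects
    return cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                       &alpha, A, k, B, k, &beta, C, m);
}

// FP16 GEMM: on the old path the FP32 inputs first had to be converted to FP16
// (extra kernels not shown) and the FP16 result converted back to FP32.
static cublasStatus_t gemm_fp16_roundtrip(cublasHandle_t handle, int m, int n, int k,
                                          const half * A16, const half * B16, half * C16) {
    const half alpha = __float2half(1.0f), beta = __float2half(0.0f);
    return cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                       &alpha, A16, k, B16, k, &beta, C16, m);
}
```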

Master:

  Device 0: AMD Instinct MI100, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl | n_batch |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |          pp64 |        434.23 ± 0.43 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         pp128 |        433.79 ± 0.39 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         pp512 |        430.21 ± 0.28 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |          pp64 |        433.94 ± 0.20 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         pp128 |        785.42 ± 0.49 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         pp512 |       1108.48 ± 0.63 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |          pp64 |        433.87 ± 0.41 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         pp128 |        783.96 ± 1.65 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         pp512 |       1106.24 ± 0.35 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |          pp64 |        432.29 ± 0.85 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         pp128 |        783.59 ± 0.92 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         pp512 |       1104.26 ± 0.70 |
  Device 0: AMD Instinct MI100, compute capability 9.0, VMM: no
  Device 0: AMD Instinct MI100, compute capability 9.0, VMM: no
| model                          |       size |     params | backend    | ngl |    sm |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ----: | ------------: | -------------------: |
| llama 70B Q4_K - Medium        |  39.59 GiB |    70.55 B | ROCm       |  99 |   row |         pp512 |        137.70 ± 5.21 |

This PR + #11244:

  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no
| model                          |       size |     params | backend    | ngl | n_batch |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | ------------: | -------------------: |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |          pp64 |        697.88 ± 2.07 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         pp128 |        698.75 ± 1.16 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         pp512 |        689.11 ± 0.74 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |      64 |         tg128 |         90.13 ± 0.48 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |          pp64 |        693.25 ± 3.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         pp128 |       1410.06 ± 2.89 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         pp512 |       2967.72 ± 4.00 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |     512 |         tg128 |         90.14 ± 0.46 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |          pp64 |        688.72 ± 2.51 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         pp128 |       1407.38 ± 2.64 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         pp512 |       2937.60 ± 5.10 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    1024 |         tg128 |         89.14 ± 0.14 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |          pp64 |        681.53 ± 1.11 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         pp128 |       1392.21 ± 3.51 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         pp512 |       2911.77 ± 2.64 |
| llama 8B Q4_K - Medium         |   4.58 GiB |     8.03 B | ROCm       |  99 |    2048 |         tg128 |         88.68 ± 0.67 |
  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no
  Device 1: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 70B Q4_K - Medium        |  39.59 GiB |    70.55 B | ROCm       |  99 |         pp512 |        371.24 ± 0.12 |

The github-actions bot added the labels "Nvidia GPU" (issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) on Jan 22, 2025.
@IMbackK (Collaborator, Author) commented on Jan 22, 2025

I fear any further improvements will require extensive changes to MMQ and the use of __builtin_amdgcn_mfma_i32_32x32x8i8.
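For reference, a very rough sketch of what using that builtin looks like, assuming the gfx908 signature (packed int8 operands, one int per lane, and a 16-wide int32 accumulator per lane); the kernel and its data layout are hypothetical and not llama.cpp code:

```cpp
// Compile for gfx908/CDNA with hipcc; assumes one 64-lane wavefront per block.
#include <hip/hip_runtime.h>

typedef int int32x16_t __attribute__((ext_vector_type(16)));

// One MFMA step: a 32x8 int8 A tile times an 8x32 int8 B tile, accumulated into
// a 32x32 int32 D tile that is spread across the 64 lanes of the wavefront.
__global__ void mfma_i8_step(const int * a_packed,  // 4 int8 values per lane, packed into one int
                             const int * b_packed,  // 4 int8 values per lane, packed into one int
                             int * d_out) {
    const int lane = threadIdx.x;  // 0..63

    int32x16_t acc;
    for (int i = 0; i < 16; ++i) acc[i] = 0;

    // The last three arguments (cbsz, abid, blgp) select broadcast modes; 0 = plain GEMM step.
    acc = __builtin_amdgcn_mfma_i32_32x32x8i8(a_packed[lane], b_packed[lane], acc, 0, 0, 0);

    // Each lane owns 16 of the 1024 int32 results of the 32x32 output tile; they
    // are simply dumped per lane here, ignoring the hardware's row/column mapping.
    for (int i = 0; i < 16; ++i) d_out[lane * 16 + i] = acc[i];
}
```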

@JohannesGaessler (Collaborator) left a comment:

I would be happy to cooperate on changes to MMQ, but my expectation is that it will not at all be easy. The performance is very sensitive to the exact memory layout due to shared memory bank conflicts, so it may be better to just write an explicit ROCm kernel. The correct way to utilize the AMD equivalent of tensor cores via HIP would be to modify mma.cuh (unless this interface for some reason cannot be used). I am currently working on language model training and have a working CPU/CUDA prototype. If this is of interest to you, it would make sense to at some point investigate similar performance optimizations for OUT_PROD as you did for MUL_MAT in this PR. Also, one of my next goals will be to re-write the FlashAttention code to use primitives like those in mma.cuh instead of nvcuda::wmma. So if you are interested in AMD support, it may make sense to check ahead of time whether the mma.cuh interface will need to be adjusted.

@IMbackK (Collaborator, Author) commented on Jan 24, 2025

> I would be happy to cooperate on changes to MMQ, but my expectation is that it will not at all be easy. The performance is very sensitive to the exact memory layout due to shared memory bank conflicts, so it may be better to just write an explicit ROCm kernel. The correct way to utilize the AMD equivalent of tensor cores via HIP would be to modify mma.cuh (unless this interface for some reason cannot be used). I am currently working on language model training and have a working CPU/CUDA prototype. If this is of interest to you, it would make sense to at some point investigate similar performance optimizations for OUT_PROD as you did for MUL_MAT in this PR. Also, one of my next goals will be to re-write the FlashAttention code to use primitives like those in mma.cuh instead of nvcuda::wmma. So if you are interested in AMD support, it may make sense to check ahead of time whether the mma.cuh interface will need to be adjusted.

Yeah, I know. I think eventually we would have to effectively split the CUDA and HIP backends, but I don't expect I will ever have the bandwidth to maintain the result, so for now we need to keep things as they are. Annoyingly, for best performance you would need two sets of kernels even for HIP, since the difference in performance characteristics between RDNA and CDNA/GCN GPUs is pretty big.

I will try to look into mma.cuh soon from the perspective of CDNA.

@JohannesGaessler (Collaborator) commented:

Right now I have only defined int8 primitives in mma.cuh, but the same concepts would apply to FP16. Basically, I defined matrix tiles with dimensions I, J, and K and functions get_i, get_j, and get_k to get the indices of threads within the tiles. Tiles have either one or two 32-bit integer values per thread in a warp. The big advantage over wmma is that you now have a defined data layout and don't have to go through shared memory.
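As a minimal sketch of that idea (hypothetical struct, dimensions, and thread mapping; not the actual mma.cuh definitions): a 16x4 int tile owned by one warp, two values per thread, where the get_* functions expose the fixed mapping from a thread's register slot to matrix coordinates:

```cpp
// Sketch of a warp-level tile with a defined data layout (compile as a .cu file).
struct tile_I16_K4 {
    static constexpr int I  = 16;            // tile rows
    static constexpr int K  = 4;             // tile columns
    static constexpr int ne = (I * K) / 32;  // values per thread in a 32-thread warp (= 2)

    int x[ne];                               // this thread's share of the tile

    // Row of the l-th value held by the calling thread (illustrative mapping only).
    static __device__ int get_i(int l) {
        return (l * 32 + threadIdx.x % 32) / K;
    }

    // Column of the l-th value held by the calling thread (illustrative mapping only).
    static __device__ int get_k(int l) {
        return (l * 32 + threadIdx.x % 32) % K;
    }
};

// Because the layout is defined, a load needs no shared-memory staging or shuffles:
__device__ void load_tile(tile_I16_K4 & t, const int * src, int stride) {
    for (int l = 0; l < tile_I16_K4::ne; ++l) {
        t.x[l] = src[tile_I16_K4::get_i(l) * stride + tile_I16_K4::get_k(l)];
    }
}
```

The real tiles in mma.cuh use different shapes and a hardware-matching layout; the point is only that the thread-to-element mapping is explicit rather than hidden behind wmma fragments.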

@IMbackK (Collaborator, Author) commented on Jan 24, 2025

@JohannesGaessler I fixed the nit.

Btw, do you know offhand (I have not tried profiling this yet) where this discrepancy comes from:

$ ./bin/llama-server -ngl 99 -m ~/machine-lerning/Models/llms/GGUF/Llama-3.3-70B-Instruct-Q4_K_M.gguf -c 8192
prompt eval time =    2695.92 ms /   727 tokens (    3.71 ms per token,   269.67 tokens per second)
       eval time =    2278.07 ms /    31 tokens (   73.49 ms per token,    13.61 tokens per second)
      total time =    4973.99 ms /   758 tokens
srv  update_slots: all slots are idle
request: POST /v1/chat/completions 127.0.0.1 200


$ ./bin/llama-bench -m /home/philipp/machine-lerning/Models/llms/GGUF/Llama-3.3-70B-Instruct-Q4_K_M.gguf -p 722 -ngl 99 -fa 0 -n 31
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 ROCm devices:
  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: yes
  Device 1: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: yes
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 70B Q4_K - Medium        |  39.59 GiB |    70.55 B | ROCm       |  99 |         pp727 |        403.01 ± 0.52 |
| llama 70B Q4_K - Medium        |  39.59 GiB |    70.55 B | ROCm       |  99 |          tg31 |         14.56 ± 0.03 |

Feels a bit large to me.

@JohannesGaessler (Collaborator) commented:

I could be misremembering, but I think the server had a comparatively small default physical batch size of 128. If you are using cuBLAS/rocBLAS GEMM, that imposes a large overhead for dequantizing the weight matrices to FP16/FP32. You can check this by comparing performance for an FP32 model.

@JohannesGaessler merged commit 9fbadae into ggml-org:master on Jan 24, 2025
45 checks passed
@JohannesGaessler (Collaborator) commented:

This PR broke FP16 GEMM, fixed by #11396.

anagri pushed a commit to BodhiSearch/llama.cpp that referenced this pull request Jan 26, 2025
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Jan 26, 2025
tinglou pushed a commit to tinglou/llama.cpp that referenced this pull request Feb 13, 2025
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Feb 26, 2025
mglambda pushed a commit to mglambda/llama.cpp that referenced this pull request Mar 8, 2025