
CUDA: fix RoPE asserts, block sizes #2833


Merged
merged 1 commit into ggml-org:master on Aug 28, 2023

Conversation

JohannesGaessler
Collaborator

This PR fixes the RoPE CUDA block sizes being too large by a factor of 2, which caused the kernel to do redundant work (I was not able to measure a performance difference, however). I do not have a model that uses GLM RoPE ready, so I was not able to test the corresponding fix. @li-plus can you please check this?

This PR also fixes the asserts which are supposed to ensure that the number of values per row is even, so that each thread can safely access 2 values.
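
(For context, a minimal sketch of the host-side launch setup described above, reconstructed from this discussion rather than copied from ggml-cuda.cu; the CUDA_ROPE_BLOCK_SIZE value and the helper name are assumptions.)

```cuda
#include <cassert>
#include <cuda_runtime.h>

// Sketch of the fixed launch setup (reconstructed from this PR discussion, not verbatim
// from ggml-cuda.cu). The CUDA_ROPE_BLOCK_SIZE value is assumed for illustration.
#define CUDA_ROPE_BLOCK_SIZE 256

static void rope_f32_cuda_sketch(const int ncols, const int nrows, cudaStream_t stream) {
    // Fixed assert: each thread rotates a pair of values within one row,
    // so the number of values per row (ncols) must be even.
    assert(ncols % 2 == 0);

    // Fixed block size: a block of CUDA_ROPE_BLOCK_SIZE threads already covers
    // 2*CUDA_ROPE_BLOCK_SIZE columns, so the previous 2*CUDA_ROPE_BLOCK_SIZE threads
    // per block did twice the necessary work.
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int  num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nrows, num_blocks_x, 1);

    // rope_f32<<<block_nums, block_dims, 0, stream>>>(/* tensor pointers and RoPE params */);
    (void) block_dims; (void) block_nums; (void) stream;
}
```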

@slaren
Member

slaren commented Aug 27, 2023

Shouldn't num_blocks_x also be changed? Such as:

const int num_blocks_x = (ncols/2 + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;

@JohannesGaessler
Collaborator Author

No, because each thread accesses 2 values, so with this ratio each value gets accessed exactly once. In other news, the issue seems to have been fixed independently for GLM RoPE via ggml-org/ggml#477, so I will remove my corresponding changes.
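
(A quick standalone check of that coverage claim, using assumed illustrative numbers; this is not code from the PR.)

```cuda
#include <cassert>

// With one thread per column pair, num_blocks_x blocks of CUDA_ROPE_BLOCK_SIZE threads
// touch every column of a row exactly once. Values below are assumed for illustration.
int main() {
    const int CUDA_ROPE_BLOCK_SIZE = 256;
    const int ncols = 4096; // must be even, per the assert fixed in this PR

    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);

    int covered = 0;
    for (int block = 0; block < num_blocks_x; ++block) {
        for (int thread = 0; thread < CUDA_ROPE_BLOCK_SIZE; ++thread) {
            const int col = 2*(CUDA_ROPE_BLOCK_SIZE*block + thread);
            if (col >= ncols) {
                continue; // same early-out as the kernel
            }
            covered += 2; // this thread handles col and col + 1
        }
    }
    assert(covered == ncols); // every value accessed exactly once
    return 0;
}
```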

@ggerganov
Member

ggerganov commented Aug 27, 2023

> No, because each thread accesses 2 values, so with this ratio each value gets accessed exactly once.

Hm, I think @slaren is right. Consider CUDA_ROPE_BLOCK_SIZE == 1 for simplicity.
This will be a noop for all threads with index > ncols/2:

https://github.com/ggerganov/llama.cpp/blob/230d46c723edf5999752e4cb67fd94edb19ef9c7/ggml-cuda.cu#L3994-L4002
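
(For readers without the permalink open, the referenced prologue looks roughly like this; a paraphrased sketch, not a verbatim copy of ggml-cuda.cu.)

```cuda
// Paraphrased sketch of the referenced kernel prologue (not verbatim): each thread owns
// the column pair (col, col + 1) and exits early once col runs past the end of the row.
static __global__ void rope_f32_sketch(const float * x, float * dst, const int ncols) {
    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (col >= ncols) {
        return; // noop for threads past the last column pair
    }

    const int row = blockDim.x*blockIdx.x + threadIdx.x;
    const int i   = row*ncols + col;

    // ... rotation of x[i] and x[i + 1] into dst[i] and dst[i + 1] elided ...
    (void) x; (void) dst; (void) i;
}
```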

Nvm: I just realized you've halved the block size. Should be good

@li-plus
Contributor

li-plus commented Aug 28, 2023

@JohannesGaessler I've tested chatglm2, which uses rope_f32_cuda, with this PR. It works fine. I don't have a model that uses rope_neox_f32_cuda, but it should work as well.

One question: why are columns now indexed by threadIdx.y instead of threadIdx.x? I would guess this hurts performance since the memory access pattern is less friendly.

Update: I just did an experiment swapping threadIdx.x and threadIdx.y. There is no difference in kernel time, so I think it will be fine for this type of workload.

@JohannesGaessler
Collaborator Author

The reason x and y were swapped is that the grid size in the y direction is limited to 65535, while in the x direction it is limited to 2147483647. This allows for a higher maximum batch size.
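
(For reference, the per-axis grid limits can be queried at runtime; a small standalone check, not part of the PR.)

```cuda
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    // On current GPUs this reports 2147483647 for x and 65535 for y and z,
    // which is why the potentially large row/batch count goes on the x axis.
    printf("maxGridSize: x=%d y=%d z=%d\n",
           prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
    return 0;
}
```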

ggerganov merged commit 92b1bbd into ggml-org:master on Aug 28, 2023
akawrykow pushed a commit to akawrykow/llama.cpp that referenced this pull request Aug 29, 2023