
cuda : add RoPE kernel for mode == 2 (NeoX) #2760

Merged: 2 commits into master, Aug 25, 2023

Conversation

@ggerganov (Member) commented Aug 24, 2023

With this change, running Falcon 7B with -ngl 32 seems to produce the correct results (running ppl tests now).

However, when I offload the KV cache and non-repeating tensors with -ngl 35, the generation is wrong.
Any ideas what could be causing this behavior?

Another observation: offloading with -ngl 35 -b 1 works correctly.
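
For context on what "mode == 2" changes, here is a minimal sketch of a NeoX-style RoPE rotation in CUDA. This is an illustration, not the kernel added in this PR: the name rope_neox_f32_sketch, the launch assumptions, and the omission of partial rotation (n_dims < ncols) and frequency scaling are all simplifications. The point is that NeoX pairs element i0 with element i0 + ncols/2, whereas the default mode rotates adjacent pairs (x[2*i0], x[2*i0 + 1]).

```cpp
// Illustrative only: a simplified NeoX-style (mode == 2) RoPE rotation.
// One block row (blockIdx.x) per tensor row; threads in y cover pair indices.
static __global__ void rope_neox_f32_sketch(const float * x, float * dst, const int ncols,
                                            const float p, const float theta_scale) {
    const int row  = blockIdx.x;
    const int half = ncols / 2;
    const int i0   = blockDim.y*blockIdx.y + threadIdx.y; // pair index: 0 .. half-1

    if (i0 >= half) {
        return;
    }

    const int i = row*ncols + i0;

    // rotation angle for this pair
    const float theta     = p * powf(theta_scale, i0);
    const float sin_theta = sinf(theta);
    const float cos_theta = cosf(theta);

    const float x0 = x[i];
    const float x1 = x[i + half]; // the partner element is half a row away, not adjacent

    dst[i]        = x0*cos_theta - x1*sin_theta;
    dst[i + half] = x0*sin_theta + x1*cos_theta;
}
```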

llama.cpp (outdated)
Comment on lines 2759 to 2760:

```cpp
offload_func_nr(cur->src[0]);
offload_func_nr(cur);
```

@ggerganov (Member, Author) commented

Was this intentionally not offloaded, or simply missed? @slaren

@slaren (Member) commented Aug 24, 2023

It is intentional, since the tensor must be accessible to the CPU (this is the embeddings). It's the same in the llama graph.

@ggerganov (Member, Author) commented

Looking at the original rope_f32 kernel, it seems that half of the Y block will be doing nothing:

kernel:
https://github.com/ggerganov/llama.cpp/blob/ac4bb6ba02b1e9744ac8c1413f58141d011ce5f5/ggml-cuda.cu#L3887-L3894

launch dims:
https://github.com/ggerganov/llama.cpp/blob/ac4bb6ba02b1e9744ac8c1413f58141d011ce5f5/ggml-cuda.cu#L4802-L4805

Not 100% sure, but even if fixed it might not make any difference.

@slaren (Member) commented Aug 24, 2023

I spent quite a bit of time trying to get this to work yesterday, and I came up with an implementation very similar to yours, but it is also not working. So I am starting to think that the problem is elsewhere. Unfortunately, debugging the CUDA backend is a bit of a nightmare.

Current master should already work with 7b with up to 33 offloaded layers, and 40b with up to 60. This change shouldn't be required for that.

@JohannesGaessler (Collaborator) left a comment

I'll need to look at the CPU code for comparison but one thing that you could try is disabling row flattening (one of the args for ggml_cuda_op) and see if that fixes the issue.

```diff
@@ -4799,13 +4798,21 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons

 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 2 == 0);
+    GGML_ASSERT(nrows % 2 == 0); // GG: is this assert really needed? I don't see why
```
A collaborator replied:
I think I mixed up ncols and nrows here.

@slaren (Member) commented Aug 24, 2023

> Looking at the original rope_f32 kernel, it seems that half of the Y block will be doing nothing:

I also noticed this, and verified that it is the case. I also have no idea why the block size is doubled. I have been using this on my branch:

```cpp
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
    GGML_ASSERT(nrows % 2 == 0);
    const dim3 block_dims(1, min(ncols/2, CUDA_ROPE_BLOCK_SIZE), 1);
    const int num_blocks_x = (ncols/2 + CUDA_ROPE_BLOCK_SIZE - 1) / (CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nrows, num_blocks_x, 1);
    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}
```

@JohannesGaessler (Collaborator) commented

To test the RoPE implementation specifically the following can be done: manually edit the logic for determining whether a particular tensor should be using CUDA so that CUDA is always used for RoPE regardless of backend. Then try running the model with 0 GPU layers. What should be happening is that the CPU is used for everything but RoPE for which data gets copied to the GPU and back again (because it has the CPU backend).
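
A rough sketch of that debugging pattern, under the assumption that only RoPE is routed to the GPU while everything else stays on the CPU. This is not llama.cpp code: rope_neox_f32_sketch is the illustrative kernel from the sketch above, and all names and launch parameters here are made up for illustration.

```cpp
#include <cuda_runtime.h>

// Illustrative only: run a single op on the GPU for a tensor that otherwise
// lives on the CPU, i.e. copy in, launch, copy back out.
static void rope_neox_on_gpu_sketch(const float * src, float * dst,
                                    int nrows, int ncols,
                                    float p, float theta_scale) {
    const size_t nbytes = (size_t) nrows * ncols * sizeof(float);

    float * d_src = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_src, nbytes);
    cudaMalloc(&d_dst, nbytes);

    // host -> device: the CPU-backend tensor is staged on the GPU just for this op
    cudaMemcpy(d_src, src, nbytes, cudaMemcpyHostToDevice);

    const int block_y = 256; // threads per block along the pair dimension
    const dim3 block_dims(1, block_y, 1);
    const dim3 block_nums(nrows, (ncols/2 + block_y - 1) / block_y, 1);
    rope_neox_f32_sketch<<<block_nums, block_dims>>>(d_src, d_dst, ncols, p, theta_scale);

    // device -> host: hand the result back so the rest of the CPU graph can use it
    cudaMemcpy(dst, d_dst, nbytes, cudaMemcpyDeviceToHost);

    cudaFree(d_src);
    cudaFree(d_dst);
}
```

If the output from this round trip matches the CPU reference for batched inputs, the kernel itself is likely fine and the problem is in the offloading logic.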

@ggerganov (Member, Author) commented Aug 24, 2023

> I spent quite a bit of time trying to get this to work yesterday, and I came up with an implementation very similar to yours, but it is also not working. So I am starting to think that the problem is elsewhere. Unfortunately, debugging the CUDA backend is a bit of a nightmare.
>
> Current master should already work with 7b with up to 33 offloaded layers, and 40b with up to 60. This change shouldn't be required for that.

Hm, how come? Wouldn't CUDA use the non-NeoX RoPE call?

Edit: strange, it does work. I'm confused. I expected it to produce nonsense because the RoPE would be wrong.

@slaren (Member) commented Aug 24, 2023

> Hm, how come? Wouldn't CUDA use the non-NeoX RoPE call?

So to be clear, 33 and 60 layers are without KV offloaded, so it runs on the CPU.

@ggerganov (Member, Author) commented

Ah got it, thanks!

> To test the RoPE implementation specifically ...

I think the NeoX RoPE in this PR is correct because -ngl 35 -b 1 works.

@slaren (Member) commented Aug 24, 2023

> I think the NeoX RoPE in this PR is correct because -ngl 35 -b 1 works.

I have verified this. 40b also works with full offloading with -b 1.

@ggerganov (Member, Author) commented

Merging this, and I will open an issue to keep track of the problem with the KV cache offload.

@ggerganov merged commit 3f460a2 into master on Aug 25, 2023
@ggerganov deleted the fix-falcon-cuda branch on August 25, 2023 at 08:56
akawrykow pushed a commit to akawrykow/llama.cpp that referenced this pull request Aug 29, 2023
* cuda : add RoPE kernel for mode == 2 (NeoX)

* falcon : do not offload the embeddings layer