
Fix embedding when embedding layer on GPU #1873


Closed
howard0su wants to merge 3 commits

Conversation

howard0su
Collaborator

Support copying data from GPU back to HOST

@howard0su
Collaborator Author

OpenCL is not needed for now but it will be useful eventually.

@JohannesGaessler
Collaborator

When is the embedding layer ever on the GPU?

@howard0su
Collaborator Author

When we create the KV cache in GPU memory.

@JohannesGaessler
Collaborator

I don't see how that is going to happen. Please provide a concrete set of parameters with which the embedding layer ends up on the GPU.

@LLukas22

I agree that the embedding layer probably isn't being offloaded to the GPU right now. But having the option to move memory back to the host could be a handy feature, especially if we're thinking about bringing these changes back into the GGML repo. This would definitely give us more freedom in terms of offloading and acceleration.

@JohannesGaessler
Collaborator

To keep the code simpler, I think this should only be merged if there is an actual use case. Currently you can set the backend of a tensor to GGML_BACKEND_CPU and it will automatically copy the data back into RAM after the calculation. I don't want to have multiple pieces of code that do the same thing unless there is a good reason for it.
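
For illustration, here is a minimal sketch of that pattern, assuming the ggml CUDA offloading API as it stood in June 2023 (ggml_cuda_assign_buffers plus the GGML_BACKEND_CPU/GGML_BACKEND_GPU backends); the weight and input tensors are placeholders, not names from this PR:

```c
#include "ggml.h"
#include "ggml-cuda.h"   // ggml_cuda_assign_buffers (CUDA build of ggml)

// Sketch: offload an intermediate result to VRAM, but leave the final tensor
// on the CPU backend so its data is copied back to host RAM automatically.
static struct ggml_tensor * build_example(struct ggml_context * ctx,
                                          struct ggml_tensor * w,    // weight, already in VRAM
                                          struct ggml_tensor * x) {  // input, in RAM
    struct ggml_tensor * hidden = ggml_mul_mat(ctx, w, x);
    ggml_cuda_assign_buffers(hidden);          // hidden gets GGML_BACKEND_GPU

    struct ggml_tensor * out = ggml_silu(ctx, hidden);
    // no ggml_cuda_assign_buffers(out): out keeps GGML_BACKEND_CPU, so after
    // the CUDA op runs its result ends up in out->data in host RAM
    return out;
}
```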

@LLukas22

Oh, I didn't know that. If I can simply set the backend to CPU, that should be good enough for my use case. In that case we don't need an additional function to get the data back to the host.

@JohannesGaessler
Collaborator

JohannesGaessler commented Jun 16, 2023

The logic is currently like this: if any out of src0, src1, or dst have the backend GGML_BACKEND_GPU, then the computation is always done on the GPU. If src0 or src1 have GGML_BACKEND_CPU then the data is copied to the GPU automatically (if they are tensors that actually modify the data, otherwise the current implementation is a little hacky). If dst has GGML_BACKEND_CPU then the data should always end up in RAM and it is not necessary to call ggml_cuda_assign_buffers since dst will not use those buffers.
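
As a rough sketch, that rule amounts to something like the following (simplified and illustrative only, not the actual ggml-cuda.cu code; in the June 2023 tree tensors still carry src0/src1 pointers):

```c
#include <stdbool.h>
#include <stddef.h>
#include "ggml.h"

// Illustrative restatement of the dispatch rule described above.
static bool would_run_on_gpu(const struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src0;
    const struct ggml_tensor * src1 = dst->src1;
    // if any of src0, src1, dst is GPU-backed, the op runs via CUDA;
    // CPU-backed sources are uploaded as needed, and a CPU-backed dst is
    // copied back into host RAM after the computation
    return dst->backend == GGML_BACKEND_GPU
        || (src0 != NULL && src0->backend == GGML_BACKEND_GPU)
        || (src1 != NULL && src1->backend == GGML_BACKEND_GPU);
}
```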

@LLukas22

Thank you, this information is exactly what I needed. However, I have a question concerning the behavior when both src0 and src1 tensors are on the GPU, but an operation that isn't CUDA-implemented is performed. Will the tensors be automatically relocated back to the CPU for the execution of the operation, or is this something that needs to be manually implemented?

I'm asking because I'm currently integrating your CUDA acceleration into rustformers\llm. I was able to implement llama without any issues, but I'm encountering some challenges with MPT, likely due to the alibi implementation that needs to run on the CPU.

I'm probably not the first and won't be the last person who is a bit confused about how the offloading works via the backends. What do you think, should we create a short README that describes the process? Or am I just stupid and the documentation already exists?

@JohannesGaessler
Collaborator

I have a question concerning the behavior when both src0 and src1 tensors are on the GPU, but an operation that isn't CUDA-implemented is performed. Will the tensors be automatically relocated back to the CPU for the execution of the operation, or is this something that needs to be manually implemented?

If any out of src0, src1, or dst have GGML_BACKEND_GPU then ggml will attempt to run the GGML_OP described by dst->op via CUDA. If the op isn't implemented an assertion error occurs. However, the scenario that you are talking about can be avoided by not applying ggml_cuda_assign_buffers to src0 and src1. Then they will have GGML_BACKEND_CPU and their data will be copied to RAM directly after it's calculated. If dst also has GGML_BACKEND_CPU then the CPU implementation will be used (except for matrix multiplications with batch size >= 32 where copying to VRAM would be worthwhile).

I'm asking because I'm currently integrating your CUDA acceleration into rustformers\llm. I was able to implement llama without any issues, but I'm encountering some challenges with MPT, likely due to the alibi implementation that needs to run on the CPU.

Here's what should work: do not apply ggml_cuda_assign_buffers to the inputs of the alibi tensor so they will be in RAM. Then, apply ggml_cuda_assign_buffers to the following tensor that uses the alibi tensor as input.
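
For example, here is a sketch of that pattern for the MPT attention scores; `K`, `Q`, `n_past`, `n_head`, and `alibi_bias_max` are placeholders, and ggml_alibi's trailing arguments are from memory and may differ between ggml versions:

```c
#include "ggml.h"
#include "ggml-cuda.h"

// Sketch: keep the alibi step on the CPU, then resume GPU offloading with the
// first tensor that consumes its output.
static struct ggml_tensor * attn_probs_with_alibi(struct ggml_context * ctx,
                                                  struct ggml_tensor * K,
                                                  struct ggml_tensor * Q,
                                                  int n_past, int n_head,
                                                  float alibi_bias_max) {
    struct ggml_tensor * scores = ggml_mul_mat(ctx, K, Q);
    // no ggml_cuda_assign_buffers(scores): the alibi input stays in host RAM

    struct ggml_tensor * biased = ggml_alibi(ctx, scores, n_past, n_head, alibi_bias_max);
    // ggml_alibi has no CUDA kernel, so this op uses the CPU implementation

    struct ggml_tensor * probs = ggml_soft_max(ctx, biased);
    ggml_cuda_assign_buffers(probs);   // offloading picks up again from this op
    return probs;
}
```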

What do you think, should we create a short README that describes the process? Or am I just stupid and the documentation already exists?

Documentation is always useful but it's a question of opportunity cost. Right now the CUDA code is changing relatively quickly so I don't want to spend time writing documentation that may become outdated after I notice that one of my earlier design decisions was bad and will need to be changed. I'm happy to answer specific questions though; I have a Mumble server that could be used for talking about it if desired.

@LLukas22

If any out of src0, src1, or dst have GGML_BACKEND_GPU then ggml will attempt to run the GGML_OP described by dst->op via CUDA. If the op isn't implemented an assertion error occurs. However, the scenario that you are talking about can be avoided by not applying ggml_cuda_assign_buffers to src0 and src1. Then they will have GGML_BACKEND_CPU and their data will be copied to RAM directly after it's calculated. If dst also has GGML_BACKEND_CPU then the CPU implementation will be used (except for matrix multiplications with batch size >= 32 where copying to VRAM would be worthwhile).

Yup, that was the exception I got, and I already suspected the alibi function. I'll give it another try tomorrow. Could you link me to your Mumble server? If I have additional questions I'll ask them there.

@JohannesGaessler
Collaborator

Either write an email to the address on my GitHub page, or add your email address to your page and I'll send you the address and password for the Mumble server.

@howard0su
Collaborator Author

Here is a repro:

.\build\bin\debug\embedding.exe -m ..\vicuna\13B\ggml-vicuna-13b-4bit-new2.bin -ngl 50 -p "what is cuda?"
main: build = 682 (602c748)
main: seed  = 1686921114
ggml_init_cublas: found 1 CUDA devices:
  Device 0: Tesla P100-PCIE-16GB
llama.cpp: loading model from ..\vicuna\13B\ggml-vicuna-13b-4bit-new2.bin
llama_model_load_internal: format     = ggjt v3 (latest)
llama_model_load_internal: n_vocab    = 32001
llama_model_load_internal: n_ctx      = 512
llama_model_load_internal: n_embd     = 5120
llama_model_load_internal: n_mult     = 256
llama_model_load_internal: n_head     = 40
llama_model_load_internal: n_layer    = 40
llama_model_load_internal: n_rot      = 128
llama_model_load_internal: ftype      = 2 (mostly Q4_0)
llama_model_load_internal: n_ff       = 13824
llama_model_load_internal: n_parts    = 1
llama_model_load_internal: model size = 13B
llama_model_load_internal: ggml ctx size =    0.09 MB
llama_model_load_internal: using CUDA for GPU acceleration
llama_model_load_internal: mem required  = 2135.98 MB (+ 1608.00 MB per state)
llama_model_load_internal: allocating batch_size x 1 MB = 512 MB VRAM for the scratch buffer
llama_model_load_internal: offloading 40 repeating layers to GPU
llama_model_load_internal: offloading non-repeating layers to GPU
llama_model_load_internal: offloading v cache to GPU
llama_model_load_internal: offloading k cache to GPU
llama_model_load_internal: offloaded 43/43 layers to GPU
llama_model_load_internal: total VRAM used: 9016 MB
....................................................................................................
llama_init_from_file: kv self size  =  400.00 MB

system_info: n_threads = 8 / 16 | AVX = 1 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 0 | ARM_FMA = 0 | F16C = 0 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 |
0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 [... rest of the embedding output omitted; every printed value is 0.000000 ...]

@JohannesGaessler
Collaborator

Okay, it seems I misunderstood the point of this PR. I can confirm that the embeddings binary does not work correctly when the non-repeating layers are offloaded. However, I think a more elegant solution would be to disable the call to offload_func_nr on line 1657 of llama.cpp on master. That will automatically copy the data to RAM and make it available.
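
Roughly, that change amounts to leaving out one offload call in the graph build (a hypothetical sketch from memory, not the exact code at llama.cpp:1657):

```c
// Hypothetical sketch; the surrounding code on master may differ.
cur = ggml_rms_norm(ctx0, inpL);   // final norm whose output feeds the embeddings
// offload_func_nr(cur);           // disabled: cur keeps GGML_BACKEND_CPU, so its
                                   // data is copied back to RAM and stays readable
```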

@JohannesGaessler
Collaborator

I made another PR #1891 that should also work as a fix but with an extremely simple change.

@howard0su
Collaborator Author

Sure, your fix is much simpler.

I am thinking about how we can make the samplers run inside CUDA and keep everything on the GPU.

@JohannesGaessler
Collaborator

Are you sure that would actually be faster? Copying small amounts of data between CPU and GPU takes a few nanoseconds. Doing that once or twice per token is not going to make a meaningful difference. So the question would then be whether sampling is suitable for GPU acceleration. If it isn't then I suspect a GPU implementation will be slower than what is currently on master.

@JohannesGaessler
Collaborator

In any case, I think the proper way to do what you want to do would be to add a backend like GGML_BACKEND_CPU_GPU or GGML_BACKEND_MIRRORED and extend the logic in ggml_cuda_op to also copy the data back to the host if dst has that backend. You would then only need to set the backend of embeddings after calling ggml_cuda_assign_buffers and the same logic could be easily applied to other tensors if necessary.
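
As a rough sketch of what that could look like (GGML_BACKEND_MIRRORED and the hook at the end of ggml_cuda_op are hypothetical, not existing API):

```c
#include <cuda_runtime.h>
#include <stddef.h>

// Hypothetical "mirrored" backend, sketched as a standalone helper: the op is
// computed on the GPU as usual, and the result bytes are then copied back into
// the tensor's host buffer so both CPU and GPU consumers can read them. In
// ggml this would mean a new enum value (e.g. GGML_BACKEND_MIRRORED) plus a
// call like this at the end of ggml_cuda_op when dst has that backend.
static cudaError_t mirror_result_to_host(void * host_dst, const void * device_src,
                                         size_t nbytes, cudaStream_t stream) {
    // async device-to-host copy on the stream that produced the result, so it
    // is ordered after the kernel; the caller synchronizes before reading
    return cudaMemcpyAsync(host_dst, device_src, nbytes,
                           cudaMemcpyDeviceToHost, stream);
}
```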

@howard0su
Collaborator Author

If it isn't then I suspect a GPU implementation will be slower than what is currently on master.

This needs more tests. My suspicion is that the synchronization between GPU and CPU needed for sampling is slow. Of course, I don't have data to prove it yet; I will do some tests and get back.

@howard0su
Collaborator Author

Not needed anymore

@howard0su howard0su closed this Jun 18, 2023