
Commit 637be12

CUDA kernel for q4_0 dequant. + mat. vec. mult.
Parent: fb62f92

8 files changed: 175 additions, 26 deletions


examples/common.cpp

Lines changed: 8 additions & 0 deletions
@@ -277,6 +277,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "--mlock") {
             params.use_mlock = true;
+        } else if (arg == "--gpu_layers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.gpu_layers = std::stoi(argv[i]);
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
         } else if (arg == "--mtest") {
@@ -421,6 +427,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     if (llama_mmap_supported()) {
         fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
+    fprintf(stderr, "  --gpu_layers          number of layers to store in VRAM\n");
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -469,6 +476,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.f16_kv     = params.memory_f16;
     lparams.use_mmap   = params.use_mmap;
     lparams.use_mlock  = params.use_mlock;
+    lparams.gpu_layers = params.gpu_layers;
     lparams.logits_all = params.perplexity;
     lparams.embedding  = params.embedding;
 
examples/common.h

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ struct gpt_params {
     bool perplexity     = false; // compute perplexity over the prompt
     bool use_mmap       = true;  // use mmap for faster loads
     bool use_mlock      = false; // use mlock to keep model in memory
+    int  gpu_layers     = 0;     // number of layers to store in VRAM
     bool mem_test       = false; // compute maximum memory usage
     bool verbose_prompt = false; // print prompt tokens before generation
 };
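
Taken together, the two changes above thread one new integer from the command line into llama_context_params. Below is a minimal C++ sketch of doing the same thing directly through the API as it exists in this revision; the model path and the layer count are placeholder values, not from the commit:

    // Offload the first N layers' weights to VRAM when creating a context.
    #include "llama.h"

    int main() {
        llama_context_params lparams = llama_context_default_params();
        lparams.gpu_layers = 20; // what passing --gpu_layers 20 would set

        llama_context * ctx = llama_init_from_file("model.bin", lparams); // hypothetical path
        if (ctx == nullptr) {
            return 1;
        }
        // ... evaluate as usual; the offloaded layers keep their quantized
        // weights resident on the GPU (see ggml_cuda_transform_tensor below)
        llama_free(ctx);
        return 0;
    }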

ggml-cuda.cu

Lines changed: 135 additions & 23 deletions
@@ -173,6 +173,52 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
+template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+    const int qk = QK4_0;
+
+    const int row = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    __shared__ float tmp[block_size]; // separate sum for each thread
+    tmp[tid] = 0;
+
+    for (int i = 0; i < ncols/block_size; i += 2) {
+        const int col = i*block_size + 2*tid;
+        const int ib = (row*ncols + col)/qk; // block index
+        const int iqs = (col%qk)/2; // quant index
+        const int iybs = col - col%qk; // y block start index
+
+        // dequantize
+        const float d = x[ib].d;
+
+        const uint8_t * pp = x[ib].qs;
+
+        const uint8_t vui = pp[iqs];
+
+        const int8_t vi0 = vui & 0xF;
+        const int8_t vi1 = vui >> 4;
+
+        const float v0 = (vi0 - 8)*d;
+        const float v1 = (vi1 - 8)*d;
+
+        // matrix multiplication
+        tmp[tid] += v0 * y[iybs + iqs + 0];
+        tmp[tid] += v1 * y[iybs + iqs + qk/2];
+    }
+
+    // sum up partial sums and write back result
+    for (int s=block_size/2; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        __syncthreads();
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
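
To make the index arithmetic in the new kernel easier to follow, here is a scalar C++ sketch of what one thread block computes for a single output row. It assumes the q4_0 block layout ggml used at this point (a float scale d followed by QK4_0/2 bytes of packed 4-bit quants); it is a reading of the kernel, not code from the commit:

    // Reference for one row: dst[row] = dot(dequantize(row of x), y).
    #include <cstdint>

    #define QK4_0 32
    struct block_q4_0 {
        float   d;             // scale
        uint8_t qs[QK4_0 / 2]; // 32 weights packed two per byte
    };

    static float dequantize_mul_mat_row_ref(const block_q4_0 * x, const float * y, int row, int ncols) {
        float sum = 0.0f;
        for (int col = 0; col < ncols; col += 2) {
            const int ib   = (row*ncols + col) / QK4_0; // q4_0 block holding this column
            const int iqs  = (col % QK4_0) / 2;         // byte index inside that block
            const int iybs = col - col % QK4_0;         // first column covered by the block

            const float   d   = x[ib].d;
            const uint8_t vui = x[ib].qs[iqs];

            const float v0 = ((int8_t)(vui & 0xF) - 8)*d; // low nibble
            const float v1 = ((int8_t)(vui >>  4) - 8)*d; // high nibble

            // the two nibbles of one byte belong to columns iybs+iqs and iybs+iqs+QK4_0/2
            sum += v0 * y[iybs + iqs + 0];
            sum += v1 * y[iybs + iqs + QK4_0/2];
        }
        return sum;
    }

In the kernel itself this loop is split across one 32-thread block per row: each thread accumulates a private partial sum in tmp[tid], and the loop over s then folds the 32 partial sums down to tmp[0] with a shared-memory tree reduction before thread 0 writes dst[row].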
@@ -198,6 +244,23 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
+static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    // static int block_size = -1;
+    // if (block_size == -1) {
+    //     int min_grid_size, max_block_size = 1;
+    //     CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
+    //     max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
+    //     block_size = 1;
+    //     while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
+    //         block_size *= 2;
+    //     }
+    // }
+    // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+    const int block_size = 32;
+    GGML_ASSERT(ncols % block_size == 0);
+    dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+}
+
 // TODO: optimize
 static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
     const half * x = (const half *) vx;
@@ -231,7 +294,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
 }
 
 // buffer pool for cuda
-#define MAX_CUDA_BUFFERS 16
+#define MAX_CUDA_BUFFERS 256
 
 struct scoped_spin_lock {
     std::atomic_flag& lock;
@@ -538,7 +601,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
 
     size_t x_size, y_size, d_size, q_size;
-    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    float * d_X;
+    if (ne11 > 1) {
+        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    }
     float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
     float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
     char  * d_Q = (char  *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
@@ -553,31 +619,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
             cudaEvent_t  cudaEvent   = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
 
-            float * c_X = d_X + i * x_ne;
             float * c_Y = d_Y + i * y_ne;
             float * c_D = d_D + i * d_ne;
             char  * c_Q = d_Q + i * q_sz;
 
-            // copy src0 and convert to fp32 on device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
-            to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
-            CUDA_CHECK(cudaGetLastError());
-            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+            // copy src0 to device if necessary
+            if (src0->backend == GGML_BACKEND_CPU) {
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
+            } else if (src0->backend == GGML_BACKEND_CUDA) {
+                c_Q = ((char *) src0->data) + i * q_sz;
+            } else {
+                GGML_ASSERT(false);
+            }
+            if (ne11 == 1) {
+                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
 
-            // copy src1 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+                // copy src1 to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
 
-            // wait for conversion
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+                // wait for data
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
 
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, ne00,
-                                c_Y, ne10,
-                        &beta,  c_D, ne01));
+                // compute
+                dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+
+            } else {
+                float * c_X = d_X + i * x_ne;
+
+                // convert src0 to fp32 on device
+                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
+                CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+                // copy src1 to device
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+
+                // wait for conversion
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+                // compute
+                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
+                CUBLAS_CHECK(
+                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha, c_X, ne00,
+                                    c_Y, ne10,
+                            &beta,  c_D, ne01));
+            }
 
             // copy dst to host
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -586,7 +675,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     }
 
     CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
+    if (ne11 > 1) {
+        ggml_cuda_pool_free(d_X, x_size);
+    }
     ggml_cuda_pool_free(d_Y, y_size);
     ggml_cuda_pool_free(d_D, d_size);
     ggml_cuda_pool_free(d_Q, q_size);
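
The restructured loop above now has two paths. When src1 has a single column (ne11 == 1, the token-by-token generation case), the quantized weights are multiplied directly by the new fused kernel, so neither the fp32 staging buffer d_X nor cublasSgemm is needed; for ne11 > 1 (e.g. prompt processing) the previous dequantize-then-SGEMM path is kept. A weight already placed in VRAM (src0->backend == GGML_BACKEND_CUDA) additionally skips the per-evaluation host-to-device copy. One way to read the change: a matrix-vector product is memory-bandwidth-bound, and reading 4-bit blocks is much cheaper than first materializing an fp32 copy of the weight matrix. A back-of-the-envelope illustration with a hypothetical 4096x4096 weight, assuming the 20-bytes-per-32-weights q4_0 layout of the time:

    #include <cstdio>

    int main() {
        const long long rows = 4096, cols = 4096;     // hypothetical layer dimensions
        const long long q4_0_bytes = rows*cols/32*20; // 20-byte block per 32 weights
        const long long fp32_bytes = rows*cols*4;     // dequantized copy
        std::printf("q4_0 read: %lld MiB, fp32 read: %lld MiB per mat-vec\n",
                    q4_0_bytes >> 20, fp32_bytes >> 20);
        return 0;
    }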
@@ -602,8 +693,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
          dst->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
         return true;
     }
 
@@ -655,3 +745,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
         return 0;
     }
 }
+
+void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
+    const int64_t ne0 = tensor->ne[0];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne3 = tensor->ne[3];
+
+    const ggml_type type = tensor->type;
+    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+
+    size_t q_size;
+    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+
+    cudaStream_t cudaStream2 = g_cudaStreams2[0];
+
+    // copy tensor to device
+    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
+    CUDA_CHECK(cudaDeviceSynchronize());
+
+    tensor->data = d_Q;
+    tensor->backend = GGML_BACKEND_CUDA;
+}
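
A short usage sketch for the new helper (a hypothetical caller; the call sites actually added by this commit are in llama.cpp below):

    #include "ggml.h"
    #include "ggml-cuda.h"

    // Pin one quantized weight tensor in VRAM.
    void offload_weight(struct ggml_tensor * w) {
        // on entry w->backend is GGML_BACKEND_CPU (the default this commit adds
        // in ggml_new_tensor_impl) and w->data points at host memory
        ggml_cuda_transform_tensor(w);
        // now w->data is a device buffer and w->backend == GGML_BACKEND_CUDA, so
        // ggml_cuda_mul_mat_q_f32 skips the per-evaluation host-to-device copy
    }

Note that the device buffer comes from the CUDA buffer pool and is never freed by this helper, so offloaded tensors stay resident for the lifetime of the process.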

ggml-cuda.h

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
+void   ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+
 #ifdef __cplusplus
 }
 #endif

ggml.c

Lines changed: 1 addition & 0 deletions
@@ -3702,6 +3702,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
 
     *result = (struct ggml_tensor) {
         /*.type    =*/ type,
+        /*.backend =*/ GGML_BACKEND_CPU,
         /*.n_dims  =*/ n_dims,
         /*.ne      =*/ { 1, 1, 1, 1 },
         /*.nb      =*/ { 0, 0, 0, 0 },

ggml.h

Lines changed: 7 additions & 1 deletion
@@ -243,6 +243,11 @@ extern "C" {
         GGML_TYPE_COUNT,
     };
 
+    enum ggml_backend {
+        GGML_BACKEND_CPU = 0,
+        GGML_BACKEND_CUDA = 1,
+    };
+
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN = -1,
@@ -322,6 +327,7 @@ extern "C" {
     // n-dimensional tensor
     struct ggml_tensor {
         enum ggml_type type;
+        enum ggml_backend backend;
 
         int n_dims;
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -352,7 +358,7 @@ extern "C" {
 
         char name[32];
 
-        char padding[8]; // TODO: remove and add padding to name?
+        char padding[9]; // TODO: remove and add padding to name?
     };
 
     // computation graph

llama.cpp

Lines changed: 20 additions & 2 deletions
@@ -9,6 +9,9 @@
 #include "llama.h"
 
 #include "ggml.h"
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
 
 #include <array>
 #include <ctime>
@@ -816,6 +819,7 @@ struct llama_context_params llama_context_default_params() {
         /*.vocab_only                  =*/ false,
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
+        /*.gpu_layers                  =*/ 0,
         /*.embedding                   =*/ false,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
@@ -879,6 +883,7 @@ static void llama_model_load_internal(
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
+        int gpu_layers,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void * progress_callback_user_data) {
@@ -1021,6 +1026,18 @@ static void llama_model_load_internal(
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
     model.mapping = std::move(ml->mapping);
+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) {
+        auto & layer = model.layers[i];
+        ggml_cuda_transform_tensor(layer.wq);
+        ggml_cuda_transform_tensor(layer.wk);
+        ggml_cuda_transform_tensor(layer.wv);
+        ggml_cuda_transform_tensor(layer.wo);
+        ggml_cuda_transform_tensor(layer.w1);
+        ggml_cuda_transform_tensor(layer.w2);
+        ggml_cuda_transform_tensor(layer.w3);
+    }
+#endif
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -1034,11 +1051,12 @@ static bool llama_model_load(
         ggml_type memory_type,
         bool use_mmap,
        bool use_mlock,
+        int gpu_layers,
         bool vocab_only,
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
+        llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, gpu_layers,
                                   vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::string & err) {
@@ -2097,7 +2115,7 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
-                          params.use_mmap, params.use_mlock, params.vocab_only,
+                          params.use_mmap, params.use_mlock, params.gpu_layers, params.vocab_only,
                           params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
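
Only the seven per-layer weight matrices (wq, wk, wv, wo, w1, w2, w3) are offloaded; token embeddings, the output head, the norms, and the KV cache remain in system memory. For choosing a --gpu_layers value, a rough per-layer VRAM estimate, using hypothetical LLaMA-7B-style dimensions and the 20-bytes-per-32-weights q4_0 layout of the time (nothing here is read from an actual model file):

    #include <cstdio>

    int main() {
        const long long n_embd = 4096;  // hypothetical
        const long long n_ff   = 11008; // hypothetical
        // wq, wk, wv, wo: n_embd*n_embd elements each; w1, w2, w3: n_embd*n_ff each
        const long long elems_per_layer = 4*n_embd*n_embd + 3*n_embd*n_ff;
        const long long bytes_per_layer = elems_per_layer/32*20; // q4_0: 20 bytes per 32 weights
        std::printf("~%.0f MiB of VRAM per offloaded q4_0 layer\n", bytes_per_layer/(1024.0*1024.0));
        return 0;
    }

The std::min(gpu_layers, int(hparams.n_layer)) guard in the loop above also means that requesting more layers than the model has simply offloads all of them.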

llama.h

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ extern "C" {
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mmap;   // use mmap if possible
         bool use_mlock;  // force system to keep model in RAM
+        int  gpu_layers; // number of layers to store in VRAM
         bool embedding;  // embedding mode only
 
         // called with a progress value between 0 and 1, pass NULL to disable
