Skip to content

Commit 86b170d

Browse files
committed
cuBLAS: dequantize simultaneously while copying memory
1 parent 92a6e13 commit 86b170d

File tree

3 files changed

+53
-41
lines changed

3 files changed

+53
-41
lines changed

ggml-cuda.cu

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,27 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
264264
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
265265
}
266266

267+
// Look up the CUDA dequantization kernel launcher for a quantized tensor type.
//
// type: a GGML quantization format (Q4_0 .. Q8_0).
// Returns the matching dequantize_row_q*_cuda function, or nullptr when the
// given type has no CUDA dequantizer (callers are expected to check for that).
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q4_2: return dequantize_row_q4_2_cuda;
        case GGML_TYPE_Q4_3: return dequantize_row_q4_3_cuda;
        case GGML_TYPE_Q5_0: return dequantize_row_q5_0_cuda;
        case GGML_TYPE_Q5_1: return dequantize_row_q5_1_cuda;
        case GGML_TYPE_Q8_0: return dequantize_row_q8_0_cuda;
        // unsupported / non-quantized types have no CUDA dequantizer
        default:             return nullptr;
    }
}
287+
267288
// buffer pool for cuda
268289
#define MAX_CUDA_BUFFERS 16
269290

@@ -323,18 +344,22 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
323344
CUDA_CHECK(cudaFree(ptr));
324345
}
325346

326-
cublasHandle_t g_cublasH = NULL;
327-
cudaStream_t g_cudaStream = NULL;
347+
cublasHandle_t g_cublasH = nullptr;
348+
cudaStream_t g_cudaStream = nullptr;
349+
cudaStream_t g_cudaStream2 = nullptr;
350+
cudaEvent_t g_cudaEvent = nullptr;
328351

329-
void ggml_init_cublas(void) {
330-
if (g_cublasH == NULL) {
352+
void ggml_init_cublas() {
353+
if (g_cublasH == nullptr) {
331354
// create cublas handle, bind a stream
332355
CUBLAS_CHECK(cublasCreate(&g_cublasH));
333-
334356
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
335-
336357
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
337358

359+
// create additional stream and event for synchronization
360+
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
361+
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
362+
338363
// configure logging to stdout
339364
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
340365
}

ggml-cuda.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <cublas_v2.h>
22
#include <cuda_runtime.h>
3+
#include "ggml.h"
34

45
#ifdef __cplusplus
56
extern "C" {
@@ -25,7 +26,9 @@ extern "C" {
2526
} while (0)
2627

2728
extern cublasHandle_t g_cublasH;
28-
extern cudaStream_t g_cudaStream;
29+
extern cudaStream_t g_cudaStream;
30+
extern cudaStream_t g_cudaStream2;
31+
extern cudaEvent_t g_cudaEvent;
2932

3033
void ggml_init_cublas(void);
3134
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
@@ -39,6 +42,9 @@ void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t st
3942
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
4043
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
4144

45+
typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
46+
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
47+
4248
#ifdef __cplusplus
4349
}
4450
#endif

ggml.c

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8218,7 +8218,7 @@ static void ggml_compute_forward_mul_mat_f32(
82188218
#if defined(GGML_USE_CUBLAS)
82198219
const float alpha = 1.0f;
82208220
const float beta = 0.0f;
8221-
const int x_ne = ne01 * ne10;
8221+
const int x_ne = ne01 * ne00;
82228222
const int y_ne = ne11 * ne10;
82238223
const int d_ne = ne11 * ne01;
82248224

@@ -8416,7 +8416,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
84168416

84178417
const float alpha = 1.0f;
84188418
const float beta = 0.0f;
8419-
const int x_ne = ne01 * ne10;
8419+
const int x_ne = ne01 * ne00;
84208420
const int y_ne = ne11 * ne10;
84218421
const int d_ne = ne11 * ne01;
84228422

@@ -8663,41 +8663,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
86638663
#if defined(GGML_USE_CUBLAS)
86648664
const float alpha = 1.0f;
86658665
const float beta = 0.0f;
8666-
const int x_ne = ne01 * ne10;
8666+
const int x_ne = ne01 * ne00;
86678667
const int y_ne = ne11 * ne10;
86688668
const int d_ne = ne11 * ne01;
86698669

86708670
size_t x_size, y_size, d_size, q_size;
8671-
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8672-
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8673-
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8674-
float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
8671+
float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8672+
float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8673+
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8674+
void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
86758675

8676-
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8677-
if (type == GGML_TYPE_Q4_0) {
8678-
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8679-
}
8680-
else if (type == GGML_TYPE_Q4_1) {
8681-
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8682-
}
8683-
else if (type == GGML_TYPE_Q4_2) {
8684-
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8685-
}
8686-
else if (type == GGML_TYPE_Q4_3) {
8687-
dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
8688-
}
8689-
else if (type == GGML_TYPE_Q5_0) {
8690-
dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
8691-
}
8692-
else if (type == GGML_TYPE_Q5_1) {
8693-
dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
8694-
}
8695-
else if (type == GGML_TYPE_Q8_0) {
8696-
dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
8697-
}
8698-
else {
8699-
GGML_ASSERT(false);
8700-
}
8676+
const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
8677+
GGML_ASSERT(dequantize_row_q_cuda != NULL);
87018678
#else
87028679
float * const wdata = params->wdata;
87038680
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
@@ -8713,10 +8690,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
87138690
// copy and dequantize on device
87148691
CUDA_CHECK(
87158692
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8716-
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
8693+
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream2));
87178694

8718-
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
8695+
dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
87198696
CUDA_CHECK(cudaGetLastError());
8697+
CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
87208698
#else
87218699
{
87228700
size_t id = 0;
@@ -8733,6 +8711,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
87338711
// copy data to device
87348712
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
87358713

8714+
// wait for dequantization
8715+
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
8716+
87368717
// compute
87378718
CUBLAS_CHECK(
87388719
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,

0 commit comments

Comments
 (0)