Skip to content

Commit bcb31eb

Browse files
committed
cuBLAS: dequantize simultaneously while copying memory
1 parent 0b2da20 commit bcb31eb

File tree

3 files changed

+53
-41
lines changed

3 files changed

+53
-41
lines changed

ggml-cuda.cu

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,27 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
264264
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
265265
}
266266

267+
// Map a quantized ggml tensor type to the host-side launcher that
// dequantizes rows of that type into floats on the given CUDA stream.
// Returns nullptr when `type` has no CUDA dequantization path, so the
// caller must check the result before invoking it.
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q4_2: return dequantize_row_q4_2_cuda;
        case GGML_TYPE_Q4_3: return dequantize_row_q4_3_cuda;
        case GGML_TYPE_Q5_0: return dequantize_row_q5_0_cuda;
        case GGML_TYPE_Q5_1: return dequantize_row_q5_1_cuda;
        case GGML_TYPE_Q8_0: return dequantize_row_q8_0_cuda;
        default:             return nullptr; // unsupported quantization type
    }
}
287+
267288
// buffer pool for cuda
268289
#define MAX_CUDA_BUFFERS 16
269290

@@ -323,18 +344,22 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
323344
CUDA_CHECK(cudaFree(ptr));
324345
}
325346

326-
cublasHandle_t g_cublasH = NULL;
327-
cudaStream_t g_cudaStream = NULL;
347+
cublasHandle_t g_cublasH = nullptr;
348+
cudaStream_t g_cudaStream = nullptr;
349+
cudaStream_t g_cudaStream2 = nullptr;
350+
cudaEvent_t g_cudaEvent = nullptr;
328351

329-
void ggml_init_cublas(void) {
330-
if (g_cublasH == NULL) {
352+
void ggml_init_cublas() {
353+
if (g_cublasH == nullptr) {
331354
// create cublas handle, bind a stream
332355
CUBLAS_CHECK(cublasCreate(&g_cublasH));
333-
334356
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
335-
336357
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
337358

359+
// create additional stream and event for synchronization
360+
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
361+
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
362+
338363
// configure logging to stdout
339364
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
340365
}

ggml-cuda.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <cublas_v2.h>
22
#include <cuda_runtime.h>
3+
#include "ggml.h"
34

45
#ifdef __cplusplus
56
extern "C" {
@@ -25,7 +26,9 @@ extern "C" {
2526
} while (0)
2627

2728
extern cublasHandle_t g_cublasH;
28-
extern cudaStream_t g_cudaStream;
29+
extern cudaStream_t g_cudaStream;
30+
extern cudaStream_t g_cudaStream2;
31+
extern cudaEvent_t g_cudaEvent;
2932

3033
void ggml_init_cublas(void);
3134
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
@@ -39,6 +42,9 @@ void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t st
3942
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
4043
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
4144

45+
typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
46+
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
47+
4248
#ifdef __cplusplus
4349
}
4450
#endif

ggml.c

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8200,7 +8200,7 @@ static void ggml_compute_forward_mul_mat_f32(
82008200
#if defined(GGML_USE_CUBLAS)
82018201
const float alpha = 1.0f;
82028202
const float beta = 0.0f;
8203-
const int x_ne = ne01 * ne10;
8203+
const int x_ne = ne01 * ne00;
82048204
const int y_ne = ne11 * ne10;
82058205
const int d_ne = ne11 * ne01;
82068206

@@ -8398,7 +8398,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
83988398

83998399
const float alpha = 1.0f;
84008400
const float beta = 0.0f;
8401-
const int x_ne = ne01 * ne10;
8401+
const int x_ne = ne01 * ne00;
84028402
const int y_ne = ne11 * ne10;
84038403
const int d_ne = ne11 * ne01;
84048404

@@ -8645,41 +8645,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
86458645
#if defined(GGML_USE_CUBLAS)
86468646
const float alpha = 1.0f;
86478647
const float beta = 0.0f;
8648-
const int x_ne = ne01 * ne10;
8648+
const int x_ne = ne01 * ne00;
86498649
const int y_ne = ne11 * ne10;
86508650
const int d_ne = ne11 * ne01;
86518651

86528652
size_t x_size, y_size, d_size, q_size;
8653-
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8654-
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8655-
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8656-
float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
8653+
float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8654+
float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8655+
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8656+
void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
86578657

8658-
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8659-
if (type == GGML_TYPE_Q4_0) {
8660-
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8661-
}
8662-
else if (type == GGML_TYPE_Q4_1) {
8663-
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8664-
}
8665-
else if (type == GGML_TYPE_Q4_2) {
8666-
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8667-
}
8668-
else if (type == GGML_TYPE_Q4_3) {
8669-
dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
8670-
}
8671-
else if (type == GGML_TYPE_Q5_0) {
8672-
dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
8673-
}
8674-
else if (type == GGML_TYPE_Q5_1) {
8675-
dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
8676-
}
8677-
else if (type == GGML_TYPE_Q8_0) {
8678-
dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
8679-
}
8680-
else {
8681-
GGML_ASSERT(false);
8682-
}
8658+
const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
8659+
GGML_ASSERT(dequantize_row_q_cuda != NULL);
86838660
#else
86848661
float * const wdata = params->wdata;
86858662
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
@@ -8695,10 +8672,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
86958672
// copy and dequantize on device
86968673
CUDA_CHECK(
86978674
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
8698-
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
8675+
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream2));
86998676

8700-
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
8677+
dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
87018678
CUDA_CHECK(cudaGetLastError());
8679+
CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
87028680
#else
87038681
{
87048682
size_t id = 0;
@@ -8715,6 +8693,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
87158693
// copy data to device
87168694
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
87178695

8696+
// wait for dequantization
8697+
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
8698+
87188699
// compute
87198700
CUBLAS_CHECK(
87208701
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,

0 commit comments

Comments (0)