Skip to content

Commit d3fd04e

Browse files
committed
cuBLAS: dequantize simultaneously while copying memory
1 parent b1ee8f5 commit d3fd04e

File tree

3 files changed

+49
-38
lines changed

3 files changed

+49
-38
lines changed

ggml-cuda.cu

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,25 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
227227
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
228228
}
229229

230+
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
231+
switch (type) {
232+
case GGML_TYPE_Q4_0:
233+
return dequantize_row_q4_0_cuda;
234+
case GGML_TYPE_Q4_1:
235+
return dequantize_row_q4_1_cuda;
236+
case GGML_TYPE_Q4_2:
237+
return dequantize_row_q4_2_cuda;
238+
case GGML_TYPE_Q5_0:
239+
return dequantize_row_q5_0_cuda;
240+
case GGML_TYPE_Q5_1:
241+
return dequantize_row_q5_1_cuda;
242+
case GGML_TYPE_Q8_0:
243+
return dequantize_row_q8_0_cuda;
244+
default:
245+
return nullptr;
246+
}
247+
}
248+
230249
// buffer pool for cuda
231250
#define MAX_CUDA_BUFFERS 16
232251

@@ -286,18 +305,22 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
286305
CUDA_CHECK(cudaFree(ptr));
287306
}
288307

289-
cublasHandle_t g_cublasH = NULL;
290-
cudaStream_t g_cudaStream = NULL;
308+
cublasHandle_t g_cublasH = nullptr;
309+
cudaStream_t g_cudaStream = nullptr;
310+
cudaStream_t g_cudaStream2 = nullptr;
311+
cudaEvent_t g_cudaEvent = nullptr;
291312

292-
void ggml_init_cublas(void) {
293-
if (g_cublasH == NULL) {
313+
void ggml_init_cublas() {
314+
if (g_cublasH == nullptr) {
294315
// create cublas handle, bind a stream
295316
CUBLAS_CHECK(cublasCreate(&g_cublasH));
296-
297317
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
298-
299318
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
300319

320+
// create additional stream and event for synchronization
321+
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
322+
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
323+
301324
// configure logging to stdout
302325
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
303326
}

ggml-cuda.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ extern "C" {
2626
} while (0)
2727

2828
extern cublasHandle_t g_cublasH;
29-
extern cudaStream_t g_cudaStream;
29+
extern cudaStream_t g_cudaStream;
30+
extern cudaStream_t g_cudaStream2;
31+
extern cudaEvent_t g_cudaEvent;
3032

3133
void ggml_init_cublas(void);
3234
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
@@ -41,6 +43,9 @@ void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t st
4143

4244
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
4345

46+
typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
47+
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
48+
4449
#ifdef __cplusplus
4550
}
4651
#endif

ggml.c

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8033,7 +8033,7 @@ static void ggml_compute_forward_mul_mat_f32(
80338033
#if defined(GGML_USE_CUBLAS)
80348034
const float alpha = 1.0f;
80358035
const float beta = 0.0f;
8036-
const int x_ne = ne01 * ne10;
8036+
const int x_ne = ne01 * ne00;
80378037
const int y_ne = ne11 * ne10;
80388038
const int d_ne = ne11 * ne01;
80398039

@@ -8239,7 +8239,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
82398239

82408240
const float alpha = 1.0f;
82418241
const float beta = 0.0f;
8242-
const int x_ne = ne01 * ne10;
8242+
const int x_ne = ne01 * ne00;
82438243
const int y_ne = ne11 * ne10;
82448244
const int d_ne = ne11 * ne01;
82458245

@@ -8498,39 +8498,19 @@ static void ggml_compute_forward_mul_mat_q_f32(
84988498
#if defined(GGML_USE_CUBLAS)
84998499
const float alpha = 1.0f;
85008500
const float beta = 0.0f;
8501-
const int x_ne = ne01 * ne10;
8501+
const int x_ne = ne01 * ne00;
85028502
const int y_ne = ne11 * ne10;
85038503
const int d_ne = ne11 * ne01;
85048504

85058505
size_t x_size, y_size, d_size, q_size;
8506-
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8507-
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8508-
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8509-
float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
8506+
float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
8507+
float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
8508+
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8509+
void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
85108510

8511-
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
8512-
if (type == GGML_TYPE_Q4_0) {
8513-
dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
8514-
}
8515-
else if (type == GGML_TYPE_Q4_1) {
8516-
dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
8517-
}
8518-
else if (type == GGML_TYPE_Q4_2) {
8519-
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
8520-
}
8521-
else if (type == GGML_TYPE_Q5_0) {
8522-
dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
8523-
}
8524-
else if (type == GGML_TYPE_Q5_1) {
8525-
dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
8526-
}
8527-
else if (type == GGML_TYPE_Q8_0) {
8528-
dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
8529-
}
8530-
else {
8531-
GGML_ASSERT(false);
8532-
}
8533-
#elif !defined(GGML_USE_CLBLAST)
8511+
const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
8512+
GGML_ASSERT(dequantize_row_q_cuda != NULL);
8513+
#else
85348514
float * const wdata = params->wdata;
85358515
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
85368516
#endif
@@ -8545,7 +8525,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
85458525
// copy and dequantize on device
85468526
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream));
85478527

8548-
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
8528+
dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
85498529
CUDA_CHECK(cudaGetLastError());
85508530
#elif defined(GGML_USE_CLBLAST)
85518531
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
@@ -8565,6 +8545,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
85658545
// copy data to device
85668546
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
85678547

8548+
// wait for dequantization
8549+
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
8550+
85688551
// compute
85698552
CUBLAS_CHECK(
85708553
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,

0 commit comments

Comments (0)