
Commit cf93fdc

cuBLAS: refactor, convert fp16 to fp32 on device
1 parent b925f1f commit cf93fdc
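
The gist of the change, as a standalone sketch: fp16 data is uploaded as-is, expanded to fp32 by a small device kernel, and the converted matrix is fed to cublasSgemm. This is illustrative only — the toy 64x64 sizes, fill values, and the simplified kernel signature are assumptions, and error checking is dropped; the real loop in the diff below additionally overlaps the conversion with the upload of src1 using a second stream and an event.

// sketch: convert fp16 to fp32 on the device, then run cublasSgemm
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

// one element per block, as in the diff's convert_fp16_to_fp32 kernel
static __global__ void fp16_to_fp32(const half * x, float * y) {
    const int i = blockIdx.x;
    y[i] = __half2float(x[i]);
}

// fill the fp16 matrix on the device so the host never handles half values
static __global__ void fill_fp16(half * x) {
    x[blockIdx.x] = __float2half(0.01f * blockIdx.x);
}

int main() {
    const int n = 64, k = 64;                  // small square matrices for the demo
    half * dA16; float * dA32, * dB, * dC;
    cudaMalloc(&dA16, n * k * sizeof(half));
    cudaMalloc(&dA32, n * k * sizeof(float));
    cudaMalloc(&dB,   k * n * sizeof(float));
    cudaMalloc(&dC,   n * n * sizeof(float));

    std::vector<float> B(k * n, 1.0f), C(n * n);
    cudaMemcpy(dB, B.data(), B.size() * sizeof(float), cudaMemcpyHostToDevice);
    fill_fp16<<<n * k, 1>>>(dA16);

    // fp16 -> fp32 on the device instead of converting on the host
    fp16_to_fp32<<<n * k, 1>>>(dA16, dA32);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;
    // same CUBLAS_OP_T / CUBLAS_OP_N layout as the cublasSgemm calls in the diff
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, n, n, k,
                &alpha, dA32, k, dB, k, &beta, dC, n);

    cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);
    printf("C[0][0] = %f\n", C[0]);

    cublasDestroy(handle);
    cudaFree(dA16); cudaFree(dA32); cudaFree(dB); cudaFree(dC);
    return 0;
}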

File tree

4 files changed: +287 −249 lines


ggml-cuda.cu

Lines changed: 223 additions & 34 deletions
@@ -1,11 +1,37 @@
+#include <cstdint>
 #include <stdint.h>
 #include <stdio.h>
-#include <cuda_fp16.h>
 #include <atomic>
-#include "ggml-cuda.h"
 
-typedef uint16_t ggml_fp16_t;
-static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+#include <cuda_runtime.h>
+#include <cublas_v2.h>
+#include <cuda_fp16.h>
+
+#include "ggml-cuda.h"
+#include "ggml.h"
+
+static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+#define CUDA_CHECK(err) \
+    do { \
+        cudaError_t err_ = (err); \
+        if (err_ != cudaSuccess) { \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+                cudaGetErrorString(err_)); \
+            exit(1); \
+        } \
+    } while (0)
+
+#define CUBLAS_CHECK(err) \
+    do { \
+        cublasStatus_t err_ = (err); \
+        if (err_ != CUBLAS_STATUS_SUCCESS) { \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+            exit(1); \
+        } \
+    } while (0)
+
+typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
 
 #define QK4_0 32
 typedef struct {
@@ -24,23 +50,23 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 b
 
 #define QK4_2 16
 typedef struct {
-    __half d;              // delta
+    half d;                // delta
     uint8_t qs[QK4_2 / 2]; // nibbles / quants
 } block_q4_2;
 static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
 
 #define QK5_0 32
 typedef struct {
-    __half d;              // delta
+    half d;                // delta
     uint8_t qh[4];         // 5-th bit of quants
     uint8_t qs[QK5_0 / 2]; // nibbles / quants
 } block_q5_0;
 static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
 
 #define QK5_1 32
 typedef struct {
-    __half d;              // delta
-    __half m;              // min
+    half d;                // delta
+    half m;                // min
     uint32_t qh;           // 5-th bit of quants
     uint8_t qs[QK5_1 / 2]; // nibbles / quants
 } block_q5_1;
@@ -197,37 +223,49 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
 
-void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_1;
     dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_2;
     dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK5_0;
     dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK5_1;
     dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK8_0;
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
 
-dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
+static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
+    const half * x = (const half *) vx;
+
+    const int i = blockIdx.x;
+
+    y[i] = __half2float(x[i]);
+}
+
+static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
+    convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
+}
+
+static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
@@ -241,6 +279,8 @@ dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
            return dequantize_row_q5_1_cuda;
        case GGML_TYPE_Q8_0:
            return dequantize_row_q8_0_cuda;
+       case GGML_TYPE_F16:
+           return convert_fp16_to_fp32_cuda;
        default:
            return nullptr;
    }
@@ -271,7 +311,7 @@ struct cuda_buffer {
 static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
 static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
 
-void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -290,7 +330,7 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-void ggml_cuda_pool_free(void * ptr, size_t size) {
+static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
@@ -305,17 +345,19 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
-cublasHandle_t g_cublasH = nullptr;
-cudaStream_t g_cudaStream = nullptr;
-cudaStream_t g_cudaStream2 = nullptr;
-cudaEvent_t g_cudaEvent = nullptr;
+static cublasHandle_t g_cublasH = nullptr;
+static cudaStream_t g_cudaStream = nullptr;
+static cudaStream_t g_cudaStream2 = nullptr;
+static cudaEvent_t g_cudaEvent = nullptr;
 
 void ggml_init_cublas() {
     if (g_cublasH == nullptr) {
         // create cublas handle, bind a stream
         CUBLAS_CHECK(cublasCreate(&g_cublasH));
         CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
         CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
+        // enable tensor cores
+        CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TENSOR_OP_MATH));
 
         // create additional stream and event for synchronization
         CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
@@ -326,7 +368,27 @@ void ggml_init_cublas() {
     }
 }
 
-cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
+void * ggml_cuda_host_malloc(size_t size) {
+    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
+        return nullptr;
+    }
+
+    void * ptr = nullptr;
+    cudaError_t err = cudaMallocHost((void **) &ptr, size);
+    if (err != cudaSuccess) {
+        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
+            size/1024.0/1024.0, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    return ptr;
+}
+
+void ggml_cuda_host_free(void * ptr) {
+    CUDA_CHECK(cudaFreeHost(ptr));
+}
+
+static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
     const uint64_t ne0 = src->ne[0];
     const uint64_t ne1 = src->ne[1];
     const uint64_t nb0 = src->nb[0];
@@ -354,22 +416,149 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
     }
 }
 
-void * ggml_cuda_host_malloc(size_t size) {
-    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
-        return nullptr;
+static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+
+    size_t x_size, y_size, d_size;
+    float * d_X = (float *) ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float * d_Y = (float *) ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float * d_D = (float *) ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            // copy data to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
+
+            // compute
+            CUBLAS_CHECK(
+                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha, d_X, ne00,
+                                d_Y, ne10,
+                        &beta,  d_D, ne01));
+
+            // copy data to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+        }
     }
 
-    void * ptr = nullptr;
-    cudaError_t err = cudaMallocHost((void **) &ptr, size);
-    if (err != cudaSuccess) {
-        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
-            size/1024.0/1024.0, cudaGetErrorString(err));
-        return nullptr;
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
+static void ggml_cuda_mul_mat_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+    const ggml_type type = src0->type;
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+
+    size_t x_size, y_size, d_size, q_size;
+    float * d_X = (float *) ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float * d_Y = (float *) ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float * d_D = (float *) ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+    void * d_Q = (void *) ggml_cuda_pool_malloc(ggml_type_size(type) * x_ne / ggml_blck_size(type), &q_size);
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
+    GGML_ASSERT(to_fp32_cuda != NULL);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            // copy and convert to fp32 on device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
+
+            to_fp32_cuda(d_Q, d_X, x_ne, g_cudaStream2);
+            CUDA_CHECK(cudaGetLastError());
+            CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
+
+            // copy data to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
+
+            // wait for conversion
+            CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
+
+            // compute
+            CUBLAS_CHECK(
+                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha, d_X, ne00,
+                                d_Y, ne10,
+                        &beta,  d_D, ne01));
+
+            // copy data to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+        }
     }
 
-    return ptr;
+    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+    ggml_cuda_pool_free(d_Q, q_size);
 }
 
-void ggml_cuda_host_free(void * ptr) {
-    CUDA_CHECK(cudaFreeHost(ptr));
+bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+        src1->type == GGML_TYPE_F32 &&
+        dst->type == GGML_TYPE_F32 &&
+        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
+
+        return true;
+    }
+
+    return false;
+}
+
+void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
+
+    const ggml_type type = src0->type;
+
+    if (type == GGML_TYPE_F32) {
+        ggml_cuda_mul_mat_f32(src0, src1, dst);
+    }
+    else if (type == GGML_TYPE_F16 || ggml_is_quantized(type)) {
+        ggml_cuda_mul_mat_q(src0, src1, dst);
+    }
+    else {
+        GGML_ASSERT(false);
+    }
 }
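
After this refactor most helpers in ggml-cuda.cu become static; the entry points left for callers are ggml_init_cublas, ggml_cuda_host_malloc / ggml_cuda_host_free, ggml_cuda_can_mul_mat and ggml_cuda_mul_mat. A rough sketch of how a caller is expected to drive them — only the function names and signatures come from the diff above, the wrapper function and the surrounding ggml code are assumptions:

// hypothetical caller-side dispatch (tensor setup and CPU fallback not shown)
#include "ggml.h"
#include "ggml-cuda.h"

static void mul_mat_maybe_cuda(const struct ggml_tensor * src0,
                               const struct ggml_tensor * src1,
                               struct ggml_tensor * dst) {
    ggml_init_cublas();   // creates the handle, streams and event once; no-op afterwards

    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
        // F32 src0 goes through ggml_cuda_mul_mat_f32; F16/quantized src0 goes
        // through ggml_cuda_mul_mat_q, which converts to fp32 on the device.
        ggml_cuda_mul_mat(src0, src1, dst);
    } else {
        // fall back to the CPU mat-mul path (not shown)
    }
}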
