Commit f8d8377

cuBLAS: improve ggml_compute_forward_mul_mat_f16_f32 with pinned memory
1 parent 0bb9613 commit f8d8377

File tree: 1 file changed (+7 -6)


ggml.c

Lines changed: 7 additions & 6 deletions
@@ -8419,15 +8419,19 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     const int d_ne = ne11 * ne01;
 
     size_t x_size, y_size, d_size;
-    float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-    float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-    float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+    ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float       * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
 #else
     float * const wdata = params->wdata;
 #endif
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
 #if defined(GGML_USE_CUBLAS)
+            // copy src0 while converting src1
+            const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
+            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+
             // with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
             {
@@ -8450,13 +8454,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #endif
 
 #if defined(GGML_USE_CUBLAS)
-            const ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + i02*nb02 + i03*nb03);
             const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
-
             float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
             // copy data to device
-            CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
             CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
             // compute
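
What the reordering buys: src0 is already fp16, so its host-to-device copy needs no CPU preprocessing and can be issued before the loop that converts src1 to fp16. The conversion then runs on the CPU while the transfer is in flight instead of after it. For cudaMemcpyAsync to actually overlap with host work, the host side of the transfer must be pinned (page-locked) memory, which is what the commit title refers to. Below is a minimal standalone sketch of this overlap pattern, under stated assumptions: the names (CHECK, h_w, h_src1, h_conv, d_w, d_conv, n) and the locally created stream are illustrative, not ggml's pool buffers and g_cudaStream.

// Sketch: issue the fp16 operand's upload first, convert the fp32 operand on
// the CPU while the copy is in flight, then upload the converted data.
// Build with: nvcc -o overlap overlap.cu
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>

#define CHECK(call) do {                                             \
        cudaError_t err_ = (call);                                   \
        if (err_ != cudaSuccess) {                                   \
            fprintf(stderr, "CUDA error: %s (%s:%d)\n",              \
                    cudaGetErrorString(err_), __FILE__, __LINE__);   \
            exit(1);                                                 \
        }                                                            \
    } while (0)

int main(void) {
    const int n = 1 << 20;

    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    // Pinned (page-locked) host buffers: cudaMemcpyAsync only proceeds
    // asynchronously with respect to the host when the host memory is pinned.
    __half *h_w, *h_conv;
    CHECK(cudaMallocHost((void **) &h_w,    n * sizeof(__half)));
    CHECK(cudaMallocHost((void **) &h_conv, n * sizeof(__half)));
    float *h_src1 = (float *) malloc(n * sizeof(float));
    for (int i = 0; i < n; i++) {
        h_w[i]    = __float2half(0.5f); // stand-in for the fp16 weights (src0)
        h_src1[i] = (float) i;          // stand-in for the fp32 activations (src1)
    }

    __half *d_w, *d_conv;
    CHECK(cudaMalloc((void **) &d_w,    n * sizeof(__half)));
    CHECK(cudaMalloc((void **) &d_conv, n * sizeof(__half)));

    // 1) The fp16 operand needs no CPU work: issue its upload immediately.
    CHECK(cudaMemcpyAsync(d_w, h_w, n * sizeof(__half),
                          cudaMemcpyHostToDevice, stream));

    // 2) While that transfer is in flight, convert the fp32 operand to fp16
    //    on the CPU.
    for (int i = 0; i < n; i++) {
        h_conv[i] = __float2half(h_src1[i]);
    }

    // 3) Upload the converted operand; it queues behind the first copy on
    //    the same stream, so ordering on the device is preserved.
    CHECK(cudaMemcpyAsync(d_conv, h_conv, n * sizeof(__half),
                          cudaMemcpyHostToDevice, stream));

    // ... the GEMM (e.g. cublasGemmEx on the same stream) would go here ...

    CHECK(cudaStreamSynchronize(stream));

    CHECK(cudaFree(d_w));     CHECK(cudaFree(d_conv));
    CHECK(cudaFreeHost(h_w)); CHECK(cudaFreeHost(h_conv));
    free(h_src1);
    CHECK(cudaStreamDestroy(stream));
    return 0;
}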
