@@ -8401,15 +8401,19 @@ static void ggml_compute_forward_mul_mat_f16_f32(
8401
8401
const int d_ne = ne11 * ne01 ;
8402
8402
8403
8403
size_t x_size , y_size , d_size ;
8404
- float * d_X = ggml_cuda_pool_malloc (sizeof (float ) * x_ne , & x_size );
8405
- float * d_Y = ggml_cuda_pool_malloc (sizeof (float ) * y_ne , & y_size );
8406
- float * d_D = ggml_cuda_pool_malloc (sizeof (float ) * d_ne , & d_size );
8404
+ ggml_fp16_t * d_X = ggml_cuda_pool_malloc (sizeof (float ) * x_ne , & x_size );
8405
+ ggml_fp16_t * d_Y = ggml_cuda_pool_malloc (sizeof (float ) * y_ne , & y_size );
8406
+ float * d_D = ggml_cuda_pool_malloc (sizeof (float ) * d_ne , & d_size );
8407
8407
#else
8408
8408
float * const wdata = params -> wdata ;
8409
8409
#endif
8410
8410
for (int64_t i03 = 0 ; i03 < ne03 ; i03 ++ ) {
8411
8411
for (int64_t i02 = 0 ; i02 < ne02 ; i02 ++ ) {
8412
8412
#if defined(GGML_USE_CUBLAS )
8413
+ // copy src0 while converting src1
8414
+ const ggml_fp16_t * x = (ggml_fp16_t * ) ((char * ) src0 -> data + i02 * nb02 + i03 * nb03 );
8415
+ CUDA_CHECK (cudaMemcpyAsync (d_X , x , sizeof (ggml_fp16_t ) * x_ne , cudaMemcpyHostToDevice , g_cudaStream ));
8416
+
8413
8417
// with cuBLAS, instead of converting src0 to fp32, we convert src1 to fp16
8414
8418
ggml_fp16_t * const wdata = (ggml_fp16_t * ) params -> wdata + (ne11 * ne10 ) * (i03 * ne02 + i02 );
8415
8419
{
@@ -8432,13 +8436,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
8432
8436
#endif
8433
8437
8434
8438
#if defined(GGML_USE_CUBLAS )
8435
- const ggml_fp16_t * x = (ggml_fp16_t * ) ((char * ) src0 -> data + i02 * nb02 + i03 * nb03 );
8436
8439
const ggml_fp16_t * y = (ggml_fp16_t * ) wdata ;
8437
-
8438
8440
float * d = (float * ) ((char * ) dst -> data + i02 * nb2 + i03 * nb3 );
8439
8441
8440
8442
// copy data to device
8441
- CUDA_CHECK (cudaMemcpyAsync (d_X , x , sizeof (ggml_fp16_t ) * x_ne , cudaMemcpyHostToDevice , g_cudaStream ));
8442
8443
CUDA_CHECK (cudaMemcpyAsync (d_Y , y , sizeof (ggml_fp16_t ) * y_ne , cudaMemcpyHostToDevice , g_cudaStream ));
8443
8444
8444
8445
// compute
0 commit comments