@@ -8419,15 +8419,19 @@ static void ggml_compute_forward_mul_mat_f16_f32(
8419
8419
const int d_ne = ne11 * ne01 ;
8420
8420
8421
8421
size_t x_size , y_size , d_size ;
8422
- float * d_X = ggml_cuda_pool_malloc (sizeof (float ) * x_ne , & x_size );
8423
- float * d_Y = ggml_cuda_pool_malloc (sizeof (float ) * y_ne , & y_size );
8424
- float * d_D = ggml_cuda_pool_malloc (sizeof (float ) * d_ne , & d_size );
8422
+ ggml_fp16_t * d_X = ggml_cuda_pool_malloc (sizeof (float ) * x_ne , & x_size );
8423
+ ggml_fp16_t * d_Y = ggml_cuda_pool_malloc (sizeof (float ) * y_ne , & y_size );
8424
+ float * d_D = ggml_cuda_pool_malloc (sizeof (float ) * d_ne , & d_size );
8425
8425
#else
8426
8426
float * const wdata = params -> wdata ;
8427
8427
#endif
8428
8428
for (int64_t i03 = 0 ; i03 < ne03 ; i03 ++ ) {
8429
8429
for (int64_t i02 = 0 ; i02 < ne02 ; i02 ++ ) {
8430
8430
#if defined(GGML_USE_CUBLAS )
8431
+ // copy src0 while converting src1
8432
+ const ggml_fp16_t * x = (ggml_fp16_t * ) ((char * ) src0 -> data + i02 * nb02 + i03 * nb03 );
8433
+ CUDA_CHECK (cudaMemcpyAsync (d_X , x , sizeof (ggml_fp16_t ) * x_ne , cudaMemcpyHostToDevice , g_cudaStream ));
8434
+
8431
8435
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
8432
8436
ggml_fp16_t * const wdata = (ggml_fp16_t * ) params -> wdata + (ne11 * ne10 ) * (i03 * ne02 + i02 );
8433
8437
{
@@ -8450,13 +8454,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
8450
8454
#endif
8451
8455
8452
8456
#if defined(GGML_USE_CUBLAS )
8453
- const ggml_fp16_t * x = (ggml_fp16_t * ) ((char * ) src0 -> data + i02 * nb02 + i03 * nb03 );
8454
8457
const ggml_fp16_t * y = (ggml_fp16_t * ) wdata ;
8455
-
8456
8458
float * d = (float * ) ((char * ) dst -> data + i02 * nb2 + i03 * nb3 );
8457
8459
8458
8460
// copy data to device
8459
- CUDA_CHECK (cudaMemcpyAsync (d_X , x , sizeof (ggml_fp16_t ) * x_ne , cudaMemcpyHostToDevice , g_cudaStream ));
8460
8461
CUDA_CHECK (cudaMemcpyAsync (d_Y , y , sizeof (ggml_fp16_t ) * y_ne , cudaMemcpyHostToDevice , g_cudaStream ));
8461
8462
8462
8463
// compute
0 commit comments