@@ -8218,7 +8218,7 @@ static void ggml_compute_forward_mul_mat_f32(
 #if defined(GGML_USE_CUBLAS)
     const float alpha = 1.0f;
     const float beta = 0.0f;
-    const int x_ne = ne01 * ne10;
+    const int x_ne = ne01 * ne00;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
@@ -8416,7 +8416,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 #if defined(GGML_USE_CUBLAS)
     const float alpha = 1.0f;
     const float beta = 0.0f;
-    const int x_ne = ne01 * ne10;
+    const int x_ne = ne01 * ne00;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
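Note on the x_ne change (not part of the diff): ggml_can_mul_mat() requires src0 and src1 to share their inner dimension, i.e. ne10 == ne00, so ne01 * ne10 and ne01 * ne00 are numerically equal here. The new expression sizes the src0 buffer in terms of src0's own dimensions, which states the intent directly. A minimal restatement:

    // ne10 == ne00 is guaranteed by ggml_can_mul_mat(), so the value is
    // unchanged; only the intent is clearer.
    const int x_ne = ne01 * ne00;  // elements in one src0 (X) matrix
    const int y_ne = ne11 * ne10;  // elements in one src1 (Y) matrix
    const int d_ne = ne11 * ne01;  // elements in one dst  (D) matrix

The same fix is applied to the quantized path in the next hunk.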
@@ -8663,41 +8663,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
 #if defined(GGML_USE_CUBLAS)
     const float alpha = 1.0f;
     const float beta = 0.0f;
-    const int x_ne = ne01 * ne10;
+    const int x_ne = ne01 * ne00;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
     size_t x_size, y_size, d_size, q_size;
-    float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
-    float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
-    float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-    float * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
+    float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+    float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+    float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+    void  * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
 
-    void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
-    if (type == GGML_TYPE_Q4_0) {
-        dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
-    }
-    else if (type == GGML_TYPE_Q4_1) {
-        dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
-    }
-    else if (type == GGML_TYPE_Q4_2) {
-        dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
-    }
-    else if (type == GGML_TYPE_Q4_3) {
-        dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
-    }
-    else if (type == GGML_TYPE_Q5_0) {
-        dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
-    }
-    else if (type == GGML_TYPE_Q5_1) {
-        dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
-    }
-    else if (type == GGML_TYPE_Q8_0) {
-        dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
-    }
-    else {
-        GGML_ASSERT(false);
-    }
+    const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
+    GGML_ASSERT(dequantize_row_q_cuda != NULL);
 #else
     float * const wdata = params->wdata;
     dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
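The helper that replaces the if/else ladder is defined outside this diff (in the CUDA source, alongside the kernels). A plausible sketch of its shape, assuming the dequantize_row_q_cuda_t typedef matches the function-pointer signature removed above:

    typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);

    // Hypothetical shape of the lookup: returns NULL for unsupported types so
    // the call site can assert once instead of each branch asserting.
    dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
        switch (type) {
            case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda;
            case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda;
            case GGML_TYPE_Q4_2: return dequantize_row_q4_2_cuda;
            case GGML_TYPE_Q4_3: return dequantize_row_q4_3_cuda;
            case GGML_TYPE_Q5_0: return dequantize_row_q5_0_cuda;
            case GGML_TYPE_Q5_1: return dequantize_row_q5_1_cuda;
            case GGML_TYPE_Q8_0: return dequantize_row_q8_0_cuda;
            default:             return NULL;
        }
    }

Centralizing the dispatch also gives other call sites one place to pick up new quantization types.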
@@ -8713,10 +8690,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
             // copy and dequantize on device
             CUDA_CHECK(
                 cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
-                    GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
+                    GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream2));
 
-            dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
+            dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
             CUDA_CHECK(cudaGetLastError());
+            CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
 #else
             {
                 size_t id = 0;
@@ -8733,6 +8711,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
             // copy data to device
             CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
 
+            // wait for dequantization
+            CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
+
             // compute
             CUBLAS_CHECK(
                 cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
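
Taken together, the last two hunks split the per-matrix work across two streams: the quantized weights are copied and dequantized on g_cudaStream2 while g_cudaStream copies Y, and an event gates the sgemm on the dequantized X. A minimal self-contained sketch of that pattern (the names mirror the globals above, but the setup code here is illustrative, not ggml's initialization):

    #include <cuda_runtime.h>

    static cudaStream_t g_cudaStream;   // main stream: copy Y, run sgemm
    static cudaStream_t g_cudaStream2;  // side stream: copy Q, dequantize to X
    static cudaEvent_t  g_cudaEvent;    // signals "d_X is ready" across streams

    static void init_streams(void) {
        cudaStreamCreateWithFlags(&g_cudaStream,  cudaStreamNonBlocking);
        cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking);
        // timing is not needed; disabling it makes the event cheaper
        cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming);
    }

    // Per matrix: work queued on the two streams may overlap; the event makes
    // the main stream wait only for the dequantize, not for a device-wide sync.
    static void sync_example(void) {
        // ... cudaMemcpyAsync(d_Q, ..., g_cudaStream2);
        // ... dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
        cudaEventRecord(g_cudaEvent, g_cudaStream2);        // mark completion on stream2
        // ... cudaMemcpyAsync(d_Y, ..., g_cudaStream);
        cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0);  // gate later work on the event
        // ... enqueue cublasSgemm on g_cudaStream (via cublasSetStream) ...
    }

cudaStreamWaitEvent does not block the host: it only orders work subsequently queued on g_cudaStream after the recorded event, so the CPU keeps submitting without stalling.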
0 commit comments