@@ -8200,7 +8200,7 @@ static void ggml_compute_forward_mul_mat_f32(
8200
8200
#if defined(GGML_USE_CUBLAS )
8201
8201
const float alpha = 1.0f ;
8202
8202
const float beta = 0.0f ;
8203
- const int x_ne = ne01 * ne10 ;
8203
+ const int x_ne = ne01 * ne00 ;
8204
8204
const int y_ne = ne11 * ne10 ;
8205
8205
const int d_ne = ne11 * ne01 ;
8206
8206
@@ -8398,7 +8398,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
8398
8398
8399
8399
const float alpha = 1.0f ;
8400
8400
const float beta = 0.0f ;
8401
- const int x_ne = ne01 * ne10 ;
8401
+ const int x_ne = ne01 * ne00 ;
8402
8402
const int y_ne = ne11 * ne10 ;
8403
8403
const int d_ne = ne11 * ne01 ;
8404
8404
@@ -8645,41 +8645,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
8645
8645
#if defined(GGML_USE_CUBLAS )
8646
8646
const float alpha = 1.0f ;
8647
8647
const float beta = 0.0f ;
8648
- const int x_ne = ne01 * ne10 ;
8648
+ const int x_ne = ne01 * ne00 ;
8649
8649
const int y_ne = ne11 * ne10 ;
8650
8650
const int d_ne = ne11 * ne01 ;
8651
8651
8652
8652
size_t x_size , y_size , d_size , q_size ;
8653
- float * d_X = ggml_cuda_pool_malloc (sizeof (float ) * x_ne , & x_size );
8654
- float * d_Y = ggml_cuda_pool_malloc (sizeof (float ) * y_ne , & y_size );
8655
- float * d_D = ggml_cuda_pool_malloc (sizeof (float ) * d_ne , & d_size );
8656
- float * d_Q = ggml_cuda_pool_malloc (GGML_TYPE_SIZE [type ] * x_ne / GGML_BLCK_SIZE [type ], & q_size );
8653
+ float * d_X = ggml_cuda_pool_malloc (sizeof (float ) * x_ne , & x_size );
8654
+ float * d_Y = ggml_cuda_pool_malloc (sizeof (float ) * y_ne , & y_size );
8655
+ float * d_D = ggml_cuda_pool_malloc (sizeof (float ) * d_ne , & d_size );
8656
+ void * d_Q = ggml_cuda_pool_malloc (GGML_TYPE_SIZE [type ] * x_ne / GGML_BLCK_SIZE [type ], & q_size );
8657
8657
8658
- void (* dequantize_row_q_cuda )(const void * x , float * y , int k , cudaStream_t stream ) = NULL ;
8659
- if (type == GGML_TYPE_Q4_0 ) {
8660
- dequantize_row_q_cuda = dequantize_row_q4_0_cuda ;
8661
- }
8662
- else if (type == GGML_TYPE_Q4_1 ) {
8663
- dequantize_row_q_cuda = dequantize_row_q4_1_cuda ;
8664
- }
8665
- else if (type == GGML_TYPE_Q4_2 ) {
8666
- dequantize_row_q_cuda = dequantize_row_q4_2_cuda ;
8667
- }
8668
- else if (type == GGML_TYPE_Q4_3 ) {
8669
- dequantize_row_q_cuda = dequantize_row_q4_3_cuda ;
8670
- }
8671
- else if (type == GGML_TYPE_Q5_0 ) {
8672
- dequantize_row_q_cuda = dequantize_row_q5_0_cuda ;
8673
- }
8674
- else if (type == GGML_TYPE_Q5_1 ) {
8675
- dequantize_row_q_cuda = dequantize_row_q5_1_cuda ;
8676
- }
8677
- else if (type == GGML_TYPE_Q8_0 ) {
8678
- dequantize_row_q_cuda = dequantize_row_q8_0_cuda ;
8679
- }
8680
- else {
8681
- GGML_ASSERT (false);
8682
- }
8658
+ const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda (type );
8659
+ GGML_ASSERT (dequantize_row_q_cuda != NULL );
8683
8660
#else
8684
8661
float * const wdata = params -> wdata ;
8685
8662
dequantize_row_q_t const dequantize_row_q = quantize_fns [type ].dequantize_row_q ;
@@ -8695,10 +8672,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
8695
8672
// copy and dequantize on device
8696
8673
CUDA_CHECK (
8697
8674
cudaMemcpyAsync (d_Q , (char * ) src0 -> data + i03 * nb03 + i02 * nb02 ,
8698
- GGML_TYPE_SIZE [type ] * x_ne / GGML_BLCK_SIZE [type ], cudaMemcpyHostToDevice , g_cudaStream ));
8675
+ GGML_TYPE_SIZE [type ] * x_ne / GGML_BLCK_SIZE [type ], cudaMemcpyHostToDevice , g_cudaStream2 ));
8699
8676
8700
- dequantize_row_q_cuda (d_Q , d_X , ne01 * ne00 , g_cudaStream );
8677
+ dequantize_row_q_cuda (d_Q , d_X , x_ne , g_cudaStream2 );
8701
8678
CUDA_CHECK (cudaGetLastError ());
8679
+ CUDA_CHECK (cudaEventRecord (g_cudaEvent , g_cudaStream2 ));
8702
8680
#else
8703
8681
{
8704
8682
size_t id = 0 ;
@@ -8715,6 +8693,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
8715
8693
// copy data to device
8716
8694
CUDA_CHECK (cudaMemcpyAsync (d_Y , y , sizeof (float ) * y_ne , cudaMemcpyHostToDevice , g_cudaStream ));
8717
8695
8696
+ // wait for dequantization
8697
+ CUDA_CHECK (cudaStreamWaitEvent (g_cudaStream , g_cudaEvent , 0 ));
8698
+
8718
8699
// compute
8719
8700
CUBLAS_CHECK (
8720
8701
cublasSgemm (g_cublasH , CUBLAS_OP_T , CUBLAS_OP_N ,
0 commit comments