@@ -643,6 +643,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
     const ggml_type type = src0->type;
+    const bool mul_mat_vec = ne11 == 1;
 
     const float alpha = 1.0f;
     const float beta = 0.0f;
@@ -654,7 +655,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
 
     size_t x_size, y_size, d_size, q_size;
     float * d_X;
-    if (ne11 > 1) {
+    if (!mul_mat_vec) {
         d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
     }
     float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
@@ -684,7 +685,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             } else {
                 GGML_ASSERT(false);
             }
-            if (ne11 == 1) {
+            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
 
                 // copy src1 to device
@@ -697,7 +698,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
                 dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
                 CUDA_CHECK(cudaGetLastError());
 
-            } else {
+            } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
                 float * c_X = d_X + i * x_ne;
 
                 // convert src0 to fp32 on device
@@ -728,7 +729,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     }
 
     CUDA_CHECK(cudaDeviceSynchronize());
-    if (ne11 > 1) {
+    if (!mul_mat_vec) {
         ggml_cuda_pool_free(d_X, x_size);
     }
     ggml_cuda_pool_free(d_Y, y_size);
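
The change hoists the ne11 == 1 test into a named mul_mat_vec flag so the allocation site, the kernel dispatch, and the free site all agree on one condition; when src1 is a single column, the fused dequantize + mat-vec kernel runs directly on the quantized data and the fp32 buffer d_X is never needed. Below is a minimal, self-contained C++ sketch of that dispatch pattern; the stub functions are hypothetical stand-ins for the real dmmv kernels and the dequantize + cuBLAS path, not names from ggml-cuda.cu.

// Illustrative sketch only: the two stubs below are hypothetical.
#include <cstdio>

static void dequantize_mul_mat_vec_stub() {
    std::printf("specialized dequantize_mul_mat_vec kernel\n");
}

static void dequantize_cublas_stub() {
    std::printf("general dequantization kernel + cuBLAS matrix matrix multiplication\n");
}

static void mul_mat_dispatch(int ne11) {
    // Hoisted flag: src1 with a single column is a vector, so the fused
    // kernel applies and the fp32 copy of src0 (d_X) is never needed.
    const bool mul_mat_vec = ne11 == 1;

    if (!mul_mat_vec) {
        // cuBLAS path only: allocate the fp32 buffer for dequantized src0 here.
    }

    if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
        dequantize_mul_mat_vec_stub();
    } else {           // general dequantization kernel + cuBLAS GEMM
        dequantize_cublas_stub();
    }

    if (!mul_mat_vec) {
        // cuBLAS path only: free the fp32 buffer again.
    }
}

int main() {
    mul_mat_dispatch(1); // ne11 == 1: src1 is a vector -> fused kernel
    mul_mat_dispatch(8); // ne11 > 1:  src1 is a matrix -> dequantize + cuBLAS
    return 0;
}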