Skip to content

Commit 5a0ecf7

Browse files
More readable dequantize_mul_mat_vec logic
1 parent 9da44fd commit 5a0ecf7

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

ggml-cuda.cu

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
643643
const int nb2 = dst->nb[2];
644644
const int nb3 = dst->nb[3];
645645
const ggml_type type = src0->type;
646+
const bool mul_mat_vec = ne11 == 1;
646647

647648
const float alpha = 1.0f;
648649
const float beta = 0.0f;
@@ -654,7 +655,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
654655

655656
size_t x_size, y_size, d_size, q_size;
656657
float * d_X;
657-
if (ne11 > 1) {
658+
if (!mul_mat_vec) {
658659
d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
659660
}
660661
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
@@ -684,7 +685,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
684685
} else {
685686
GGML_ASSERT(false);
686687
}
687-
if (ne11 == 1) {
688+
if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
688689
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
689690

690691
// copy src1 to device
@@ -697,7 +698,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
697698
dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
698699
CUDA_CHECK(cudaGetLastError());
699700

700-
} else {
701+
} else { // general dequantization kernel + cuBLAS matrix matrix multiplication
701702
float * c_X = d_X + i * x_ne;
702703

703704
// convert src0 to fp32 on device
@@ -728,7 +729,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
728729
}
729730

730731
CUDA_CHECK(cudaDeviceSynchronize());
731-
if (ne11 > 1) {
732+
if (!mul_mat_vec) {
732733
ggml_cuda_pool_free(d_X, x_size);
733734
}
734735
ggml_cuda_pool_free(d_Y, y_size);

0 commit comments

Comments
 (0)