Some more Q4_K and Q5_K speedup on CUDA #2346

Merged
merged 6 commits on Jul 23, 2023
114 changes: 84 additions & 30 deletions ggml-cuda.cu
@@ -1073,10 +1073,12 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;

uint16_t q16[8];
const uint8_t * q4 = (const uint8_t *)q16;

for (int i = ix; i < num_blocks_per_row; i += 2) {

const uint8_t * ql1 = x[i].qs + q_offset;
const uint8_t * ql2 = ql1 + 64;
const uint8_t * qh = x[i].qh + l0;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
@@ -1092,15 +1094,25 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,

float4 sum = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
const uint16_t * q1 = (const uint16_t *)ql1;
const uint16_t * q2 = q1 + 32;
q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[8] & 0x0f0f;
q16[2] = (q1[0] >> 4) & 0x0f0f;
q16[3] = (q1[8] >> 4) & 0x0f0f;
q16[4] = q2[0] & 0x0f0f;
q16[5] = q2[8] & 0x0f0f;
q16[6] = (q2[0] >> 4) & 0x0f0f;
q16[7] = (q2[8] >> 4) & 0x0f0f;
for (int l = 0; l < n; ++l) {
sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
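
The repacking above hoists the nibble extraction out of the inner loop: masking a uint16_t with 0x0f0f pulls the low nibbles of two packed bytes in one operation, and the q16 buffer is then read back byte-wise through the q4 alias. A minimal host-side sketch of the trick (illustration only, not part of the PR; like the kernel's aliasing, it assumes little-endian byte order):

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

int main(void) {
    const uint8_t qs[2] = {0xAB, 0xCD};  // two bytes, each holding two 4-bit quants
    uint16_t q1;
    memcpy(&q1, qs, sizeof(q1));         // view the two bytes as one uint16_t

    uint16_t q16[2];
    q16[0] = q1 & 0x0f0f;                // low nibbles of both bytes at once
    q16[1] = (q1 >> 4) & 0x0f0f;         // high nibbles of both bytes at once

    uint8_t q4[4];
    memcpy(q4, q16, sizeof(q16));        // read the results back byte-wise

    assert(q4[0] == (qs[0] & 0xF) && q4[1] == (qs[1] & 0xF));  // 0x0B, 0x0D
    assert(q4[2] == (qs[0] >> 4) && q4[3] == (qs[1] >> 4));    // 0x0A, 0x0C
    return 0;
}
```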
@@ -1554,15 +1566,23 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q4_K * bq4_K = (const block_q4_K *) vbq;

const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6
// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * (iqs / (QI8_1/2));

float sumf_d = 0.0f;
float sumf_m = 0.0f;

const float d = bq4_K->d;
const float dmin = bq4_K->dmin;

const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
// iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
// iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
// iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
// iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108

const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
const int v1 = q4[0];
const int v2 = q4[4];

const uint16_t * scales = (const uint16_t *)bq4_K->scales;
uint16_t aux[2];
@@ -1580,13 +1600,19 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
for (int i = 0; i < QR4_K; ++i) {

const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
const int * q8 = (const int *)bq8i->qs + (iqs%4);
const int ui1 = q8[0];
const int ui2 = q8[4];

const int vi = (v >> (4*i)) & 0x0F0F0F0F;
const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;

sumf_d += d8i * (__dp4a(vi, ui, 0) * sc[i]); // SIMD dot product
sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]); // multiply constant part of q4_K with sum of q8_1 values
const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));

sumf_d += d8i * (dot1 * sc[i]);
sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
}

return d*sumf_d - dmin*sumf_m;
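
Both dot products lean on the __dp4a intrinsic, which multiplies four packed 8-bit lanes and accumulates into a 32-bit integer; chaining two calls, as in __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)), therefore sums eight lane products per scale. For reference, a scalar emulation of the signed (int, int, int) overload's semantics (dp4a_ref is a hypothetical helper for illustration, not CUDA's implementation):

```c
#include <stdint.h>

// Scalar emulation of signed __dp4a(a, b, c): interpret a and b as four packed
// int8 lanes, multiply lane-wise, and accumulate the products into c.
static int dp4a_ref(int a, int b, int c) {
    for (int k = 0; k < 4; ++k) {
        const int8_t ak = (int8_t)(a >> (8 * k));
        const int8_t bk = (int8_t)(b >> (8 * k));
        c += ak * bk;
    }
    return c;
}
```

The signed variant is the right model here: the vi1/vi2 lanes are 4-bit values (0..15), which are non-negative in either signedness, while the q8_1 lanes are signed int8.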
@@ -1601,36 +1627,58 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q5_K * bq5_K = (const block_q5_K *) vbq;

const int bq8_offset = QR5_K * (iqs / QI8_1);
const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));

float sumf_d = 0.0f;
float sumf_m = 0.0f;

const float d = bq5_K->d;
const float dmin = bq5_K->dmin;

const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]);
const int vl1 = ql[0];
const int vl2 = ql[4];

const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset;
const int vh1 = qh[0] >> bq8_offset;
const int vh2 = qh[4] >> bq8_offset;

for (int i = 0; i < QR5_K; ++i) {
const int isc = bq8_offset + i;
const uint16_t * scales = (const uint16_t *)bq5_K->scales;
uint16_t aux[2];
const int j = bq8_offset/2;
if (j < 2) {
aux[0] = scales[j+0] & 0x3f3f;
aux[1] = scales[j+2] & 0x3f3f;
} else {
aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
}
const uint8_t * sc = (const uint8_t *)aux;
const uint8_t * m = sc + 2;

uint8_t sc, m;
get_scale_min_k4(isc, bq5_K->scales, sc, m);
for (int i = 0; i < QR5_K; ++i) {

const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
const int * q8 = (const int *)bq8i->qs + (iqs%4);
const int ui1 = q8[0];
const int ui2 = q8[4];

const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F;
const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F;

const int vih1 = ((vh1 >> i) << 4) & 0x10101010;
const int vih2 = ((vh2 >> i) << 4) & 0x10101010;

const int vi1 = vil1 | vih1;
const int vi2 = vil2 | vih2;

const int vih = ((vh >> i) << 4) & 0x10101010;
const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));

const int vi = vil | vih;
sumf_d += d8i * (dot1 * sc[i]);
sumf_m += d8i * (dot2 * m[i]);

sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q5_K with sum of q8_1 values
}

return d*sumf_d - dmin*sumf_m;
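
The q5_K path differs from q4_K in two ways. First, the per-pair scales and mins are decoded inline from the packed 6-bit layout (the aux[] block above) instead of calling get_scale_min_k4 in every loop iteration. Second, the fifth bit of each quant lives in the separate qh array: ((vh >> i) << 4) & 0x10101010 moves bit i of every packed byte into bit position 4, so OR-ing it with the 4-bit value reconstructs the full 5-bit quant in each lane. A single-byte sketch of that merge (illustration only; the values are made up):

```c
#include <assert.h>
#include <stdint.h>

int main(void) {
    const uint8_t ql = 0x07;  // low 4 bits of the quant
    const uint8_t qh = 0x05;  // high bits, one per iteration i
    for (int i = 0; i < 2; ++i) {
        const uint8_t hi = (uint8_t)(((qh >> i) << 4) & 0x10);  // bit i -> bit 4
        const uint8_t q  = ql | hi;                             // full 5-bit value
        assert(q == (i == 0 ? 0x17 : 0x07));  // bit 0 set -> +16, bit 1 clear -> +0
    }
    return 0;
}
```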
@@ -2301,7 +2349,10 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
// Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per
// kernel call instead of 2. This results in better performance because the cost of computing the k-quant scales
// is better amortized.
mul_mat_vec_q<QK_K, QI4_K/2, block_q4_K, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
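
The second template parameter sets how many quant groups a warp lane owns within one block; halving it from QI4_K to QI4_K/2 doubles the quants each vec_dot_q4_K_q8_1 call covers, so the scale/min unpacking at the top of that function runs half as often. A simplified sketch of how such a template might use the qi parameter (assumed shape for illustration, not the actual mul_mat_vec_q body; mul_mat_vec_q_sketch is hypothetical and the y-block bookkeeping is simplified):

```cuda
// Sketch only: plausible skeleton showing where qi enters. Halving qi doubles
// blocks_per_warp, so each vec_dot call must cover twice as many quant groups
// and its per-block scale decoding is amortized over twice the arithmetic.
template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q_sketch(
        const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row >= nrows) {
        return;
    }

    const int blocks_per_row  = ncols / qk;
    const int blocks_per_warp = WARP_SIZE / qi;  // doubles when qi is halved

    float tmp = 0.0f;  // partial sum for this lane
    for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
        const int ibx = row * blocks_per_row + i + threadIdx.x / qi;  // x block
        const int iby = i + threadIdx.x / qi;  // y block (indexing simplified)
        const int iqs = threadIdx.x % qi;      // quant group within the block
        tmp += vec_dot_q_cuda(&((const block_q_t *) vx)[ibx],
                              &((const block_q8_1 *) vy)[iby], iqs);
    }

    // warp-level reduction of tmp and the write to dst[row] are omitted here
}
```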

@@ -2310,7 +2361,10 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
// Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per
// kernel call instead of 2. This results in better performance because the cost of computing the k-quant scales
// is better amortized.
mul_mat_vec_q<QK_K, QI5_K/2, block_q5_K, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
