Skip to content

Faster k-quants on older GPUs #1930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 50 additions & 33 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -483,15 +483,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float

const block_q2_K * x = (const block_q2_K *)vx + ib0;

const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1

const int step = 16/K_QUANTS_PER_ITERATION;

const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...7
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7

const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
Expand Down Expand Up @@ -546,27 +546,30 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
}
}

static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) {
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {

const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;

const int row = blockIdx.x;
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;

const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;

const block_q3_K * x = (const block_q3_K *)vx + ib0;

const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2; // 0, 1
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1

const int n = 2; // iterations in the inner loop
const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - 8*im; // 0...7
const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7

const uint8_t m = 1 << (4*im);

const int l0 = n*in; // 0...28 in steps of 4
const int l0 = n*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int y_offset = 128*im + l0;

Expand All @@ -577,7 +580,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float

float tmp = 0; // partial sum for thread in warp

for (int i = ix; i < num_blocks_per_row; i += 2) {
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
Expand Down Expand Up @@ -618,22 +621,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
}
}

static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) {
static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {

const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;

const int row = blockIdx.x;
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;

const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1

const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 4;
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4

const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4

const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
Expand All @@ -649,7 +655,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float

float tmp = 0; // partial sum for thread in warp

for (int i = ix; i < num_blocks_per_row; i += 2) {
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

const uint8_t * q1 = x[i].qs + q_offset;
const uint8_t * q2 = q1 + 64;
Expand Down Expand Up @@ -704,7 +710,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float

const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 4;
const int n = 2;

const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
Expand Down Expand Up @@ -743,11 +749,16 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
float4 sum = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < n; ++l) {
sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0));
sum.y += y1[l+32] * ((ql1[l] >> 4) + (qh[l] & (hm1 << 1) ? 16 : 0));
sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * ((ql2[l] >> 4) + (qh[l] & (hm2 << 1) ? 16 : 0));
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;

Expand Down Expand Up @@ -1260,7 +1271,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f

static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2;
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
Expand All @@ -1269,14 +1280,20 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f

static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 1, 1);
dequantize_mul_mat_vec_q3_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 1, 1);
dequantize_mul_mat_vec_q4_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}

static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
Expand Down