Skip to content

Commit 90b8539

Browse files
q8_1 quantization
1 parent 7ce23b2 commit 90b8539

File tree

1 file changed

+49
-41
lines changed

1 file changed

+49
-41
lines changed

ggml-cuda.cu

Lines changed: 49 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -70,9 +70,11 @@ typedef void (*ggml_cuda_op_t)(
7070

7171
// QK = number of values after dequantization
7272
// QR = QK / number of values before dequantization
73+
// QI = number of 32 bit integers before dequantization
7374

7475
#define QK4_0 32
7576
#define QR4_0 2
77+
#define QI4_0 4
7678
typedef struct {
7779
half d; // delta
7880
uint8_t qs[QK4_0 / 2]; // nibbles / quants
@@ -81,6 +83,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
8183

8284
#define QK4_1 32
8385
#define QR4_1 2
86+
#define QI4_1 4
8487
typedef struct {
8588
half d; // delta
8689
half m; // min
@@ -115,7 +118,16 @@ typedef struct {
115118
} block_q8_0;
116119
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
117120

#define QK8_1 32
#define QR8_1 1
// q8_1 block: QK8_1 values quantized to signed 8-bit integers.
// Unlike q8_0 it also stores the sum of the unquantized values, so dot
// products against quant types that carry a per-block min/offset (see
// vec_dot_q4_1_q8_1) can fold the offset term in without re-summing.
typedef struct {
    half   d;           // delta (quantization scale)
    half   s;           // sum of the unquantized values of the block
    int8_t qs[QK8_1];   // quants (was sized with QK8_0 — same value, but the q8_1 struct should use its own constant)
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_1, "wrong q8_1 block size/padding");

// Per-(sub)block dot-product function pointer: vbq points at a quantized
// x block, bq8_1 at the matching q8_1 y block, iqs indexes the quants in
// units of 32-bit ints.
typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
119131

120132
//================================= k-quants
121133

@@ -1155,25 +1167,27 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
11551167
v.y = x[ib + iqs + 1];
11561168
}
11571169

1158-
static __global__ void quantize_q8_0(const float * x, void * vy, const int k) {
1170+
static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
11591171
const int i = blockDim.x*blockIdx.x + threadIdx.x;
11601172

11611173
if (i >= k) {
11621174
return;
11631175
}
11641176

1165-
block_q8_0 * y = (block_q8_0 *) vy;
1177+
block_q8_1 * y = (block_q8_1 *) vy;
11661178

11671179
const int ib = i / QK8_0; // block index
11681180
const int iqs = i % QK8_0; // quant index
11691181

11701182
const float xi = x[i];
11711183
float amax = fabsf(xi);
1184+
float sum = xi;
11721185

11731186
__syncwarp();
11741187
#pragma unroll
11751188
for (int mask = 16; mask > 0; mask >>= 1) {
11761189
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
1190+
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
11771191
}
11781192

11791193
const float d = amax / 127;
@@ -1186,51 +1200,47 @@ static __global__ void quantize_q8_0(const float * x, void * vy, const int k) {
11861200
}
11871201

11881202
y[ib].d = d;
1203+
y[ib].s = sum;
11891204
}
11901205

// Dot product of part of a q4_0 block with the matching q8_1 quants.
// iqs indexes in units of 32-bit ints: one call reads 4 q4_0 bytes
// (8 nibbles) and the 8 corresponding q8_1 quants. The second q8_1 int is
// QI4_0 ints further on because a byte's low and high nibbles belong to
// the two halves of the q4_0 block. Requires __dp4a (SM61+).
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;

    // Load through memcpy rather than casting the int8_t/uint8_t arrays to
    // int*: avoids aliasing/alignment assumptions and matches the load of
    // vi; it compiles to a single 32-bit load on the device.
    int vi;
    int ui0, ui1;
    memcpy(&vi,  &bq4_0->qs[sizeof(int) * (iqs + 0)],     sizeof(int));
    memcpy(&ui0, &bq8_1->qs[sizeof(int) * (iqs + 0)],     sizeof(int));
    memcpy(&ui1, &bq8_1->qs[sizeof(int) * (iqs + QI4_0)], sizeof(int));

    const float d = bq4_0->d * bq8_1->d;

    // Unpack the nibbles and shift them from [0, 15] to [-8, 7].
    const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
    const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);

    // SIMD dot products over the packed int8 lanes, accumulated in one int.
    int sumi = __dp4a(vi0, ui0, 0);
    sumi     = __dp4a(vi1, ui1, sumi);

    return sumi*d;
}
12111225

// Dot product of part of a q4_1 block with the matching q8_1 quants.
// q4_1 stores unsigned nibbles q with a per-block scale d4 and min m, so
//   sum((q*d4 + m) * u*d8) = (d4*d8) * sum(q*u) + m * (d8 * sum(u)),
// and d8*sum(u) is approximated by s, the precomputed sum of the q8_1
// block's unquantized values. Requires __dp4a (SM61+).
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;

    // memcpy loads instead of int* casts of the byte arrays: no
    // aliasing/alignment assumptions, still a single 32-bit device load.
    int vi;
    int ui0, ui1;
    memcpy(&vi,  &bq4_1->qs[sizeof(int) * (iqs + 0)],     sizeof(int));
    memcpy(&ui0, &bq8_1->qs[sizeof(int) * (iqs + 0)],     sizeof(int));
    memcpy(&ui1, &bq8_1->qs[sizeof(int) * (iqs + QI4_1)], sizeof(int));

    const float d = bq4_1->d * bq8_1->d;
    const float m = bq4_1->m;
    const float s = bq8_1->s; // sum of the unquantized values of the q8_1 block

    // Unsigned nibbles; no offset to subtract since q4_1 has an explicit min.
    const int vi0 = (vi >> 0) & 0x0F0F0F0F;
    const int vi1 = (vi >> 4) & 0x0F0F0F0F;

    int sumi = __dp4a(vi0, ui0, 0);
    sumi     = __dp4a(vi1, ui1, sumi);

    // NOTE(review): dividing by QI4_1 assumes exactly QI4_1 calls cover one
    // block (each then contributes m*s/QI4_1, summing to the full m*s
    // correction) — confirm against mul_mat_vec_q's launch geometry if the
    // ints-per-block layout ever changes.
    return sumi*d + m*s / QI4_1;
}
12361246

@@ -1263,8 +1273,6 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
12631273
return;
12641274
}
12651275

1266-
const int tid = threadIdx.x;
1267-
12681276
const int blocks_per_row = ncols / qk;
12691277
const int blocks_per_warp = WARP_SIZE * sizeof(int)*2/qk;
12701278
const int ints_per_block = qk / (2 * sizeof(int));
@@ -1273,14 +1281,14 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
12731281
float tmp = 0.0f;
12741282

12751283
const block_q_t * x = (const block_q_t *) vx;
1276-
const block_q8_0 * y = (const block_q8_0 *) vy;
1284+
const block_q8_1 * y = (const block_q8_1 *) vy;
12771285

12781286
for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
1279-
const int ibx = row*blocks_per_row + i + tid/ints_per_block; // x block index
1287+
const int ibx = row*blocks_per_row + i + threadIdx.x/ints_per_block; // x block index
12801288

1281-
const int iby = i + tid/ints_per_block;
1289+
const int iby = i + threadIdx.x/ints_per_block;
12821290

1283-
const int iqs = tid % ints_per_block;
1291+
const int iqs = threadIdx.x % ints_per_block;
12841292

12851293
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
12861294
}
@@ -1292,7 +1300,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
12921300
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
12931301
}
12941302

1295-
if (tid == 0) {
1303+
if (threadIdx.x == 0) {
12961304
dst[row] = tmp;
12971305
}
12981306
}
@@ -1612,9 +1620,9 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
16121620
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
16131621
}
16141622

// Quantize k float values to q8_1 on the given stream, one thread per
// value. The grid size is rounded up so the tail is covered; the kernel
// itself bounds-checks (returns early when i >= k).
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
    const int grid_size = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; // ceil-div
    quantize_q8_1<<<grid_size, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
}
16191627

16201628
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1770,21 +1778,21 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
17701778
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
17711779
}
17721780

// Launch mul_mat_vec_q for a q4_0 matrix times a q8_1-quantized vector.
// Block layout: WARP_SIZE threads in x (one warp cooperates on a row),
// GGML_CUDA_DMMV_Y rows per block in y; the grid covers nrows rows.
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    const int num_blocks_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; // ceil-div over rows
    const dim3 grid_dims(1, num_blocks_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
    mul_mat_vec_q<QK4_0, block_q4_0, vec_dot_q4_0_q8_1>
        <<<grid_dims, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
17811789

// Launch mul_mat_vec_q for a q4_1 matrix times a q8_1-quantized vector.
// Block layout: WARP_SIZE threads in x, GGML_CUDA_DMMV_Y rows per block in y.
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
    const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
    // Template block size is QK4_1 (was QK4_0): the values are identical,
    // but the q4_1 path must name its own constant so the two can diverge.
    mul_mat_vec_q<QK4_1, block_q4_1, vec_dot_q4_1_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
17901798

@@ -2302,14 +2310,14 @@ inline void ggml_cuda_op_mul_mat_vec_q(
23022310

23032311
size_t as;
23042312
void * src1_q8_0 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_0)/QK8_0, &as);
2305-
quantize_row_q8_0_cuda(src1_ddf_i, src1_q8_0, ne00, cudaStream_main);
2313+
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_0, ne00, cudaStream_main);
23062314

23072315
switch (src0->type) {
23082316
case GGML_TYPE_Q4_0:
2309-
mul_mat_vec_q4_0_q8_0_cuda(src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
2317+
mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
23102318
break;
23112319
case GGML_TYPE_Q4_1:
2312-
mul_mat_vec_q4_1_q8_0_cuda(src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
2320+
mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
23132321
break;
23142322
default:
23152323
GGML_ASSERT(false);

0 commit comments

Comments
 (0)