@@ -70,9 +70,11 @@ typedef void (*ggml_cuda_op_t)(
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
+// QI = number of 32 bit integers before dequantization
 
 #define QK4_0 32
 #define QR4_0 2
+#define QI4_0 4
 typedef struct {
     half    d;              // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
@@ -81,6 +83,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
 
 #define QK4_1 32
 #define QR4_1 2
+#define QI4_1 4
 typedef struct {
     half d;          // delta
     half m;          // min
@@ -115,7 +118,16 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
-typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_0 * bq8_0, const int iqs);
+#define QK8_1 32
+#define QR8_1 1
+typedef struct {
+    half   d;           // delta
+    half   s;           // unquantized sum
+    int8_t qs[QK8_0];   // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
+
+typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
 
 // ================================= k-quants
 
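The new block_q8_1 format differs from block_q8_0 only in the extra field s, the unquantized sum of the 32 values in the block. The point becomes visible in vec_dot_q4_1_q8_1 further down: a q4_1 weight dequantizes to d*q + m, so its dot product with an activation block decomposes into d*sum(q*y) + m*sum(y), and sum(y) is exactly the s recorded at quantization time. The offset m therefore costs one fused multiply-add per block instead of a second pass of __dp4a calls. A scalar reference of that decomposition for one block pair (illustration only, not part of the patch; the name and the unpacked-nibble argument are hypothetical):

// Scalar reference for one q4_1 x q8_1 block: shows why storing s = sum(y)
// lets the kernel drop the extra __dp4a terms of the old q8_0 path.
// q4 holds the 32 nibble values already unpacked to [0, 15].
static float ref_dot_q4_1_q8_1(float d4, float m, const uint8_t * q4,
                               float d8, float s, const int8_t * q8) {
    int sumi = 0;
    for (int j = 0; j < 32; ++j) {
        sumi += q4[j] * q8[j]; // integer dot product, done with __dp4a on the GPU
    }
    // sum_j (d4*q4[j] + m) * (d8*q8[j])  ==  d4*d8*sumi + m * sum_j d8*q8[j],
    // and sum_j d8*q8[j] is (up to rounding) the unquantized sum s.
    return d4*d8*sumi + m*s;
}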
@@ -1155,25 +1167,27 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v.y = x[ib + iqs + 1];
 }
 
-static __global__ void quantize_q8_0(const float * x, void * vy, const int k) {
+static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    block_q8_0 * y = (block_q8_0 *) vy;
+    block_q8_1 * y = (block_q8_1 *) vy;
 
     const int ib = i / QK8_0; // block index
     const int iqs = i % QK8_0; // quant index
 
     const float xi = x[i];
     float amax = fabsf(xi);
+    float sum = xi;
 
     __syncwarp();
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
+        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
     }
 
     const float d = amax / 127;
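quantize_q8_1 assigns one thread per value, so the 32 threads of a warp cover exactly one QK8_0 = 32 value block (the launch block size, CUDA_DEQUANTIZE_BLOCK_SIZE, is a multiple of the warp size). The __shfl_xor_sync butterfly above therefore leaves every lane holding the block-wide amax and sum, which the next hunk writes into y[ib].d and y[ib].s. For readers unfamiliar with the pattern, a standalone sketch of the same reduction (illustration only, assuming a full 32-thread warp as the kernel does):

// Butterfly warp reductions equivalent to the amax/sum reduction in quantize_q8_1.
static __device__ __forceinline__ float warp_reduce_max(float v) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, mask, 32));
    }
    return v; // every lane now holds the warp-wide maximum
}

static __device__ __forceinline__ float warp_reduce_sum(float v) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, mask, 32);
    }
    return v; // every lane now holds the warp-wide sum
}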
@@ -1186,51 +1200,47 @@ static __global__ void quantize_q8_0(const float * x, void * vy, const int k) {
     }
 
     y[ib].d = d;
+    y[ib].s = sum;
 }
 
-static __device__ float vec_dot_q4_0_q8_0(const void * vbq, const block_q8_0 * bq8_0, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
     int vi;
-    int ui0, ui1;
     memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
-    memcpy(&ui0, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
-    memcpy(&ui1, &bq8_0->qs[sizeof(int) * (iqs + 4)], sizeof(int));
+    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
 
-    const float d = bq4_0->d * bq8_0->d;
+    const float d = bq4_0->d * bq8_1->d;
 
     const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
     const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
 
-    const int sumi0 = __dp4a(vi0, ui0, 0);
-    const int sumi1 = __dp4a(vi1, ui1, 0);
+    int sumi = __dp4a(vi0, ui0, 0);
+    sumi = __dp4a(vi1, ui1, sumi);
 
-    return (sumi0 + sumi1)*d;
+    return sumi*d;
 
 }
 
-static __device__ float vec_dot_q4_1_q8_0(const void * vbq, const block_q8_0 * bq8_0, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
-    int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
-    int ui0, ui1;
-    memcpy(&ui0, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
-    memcpy(&ui1, &bq8_0->qs[sizeof(int) * (iqs + 4)], sizeof(int));
+    const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
 
-    const float d4_1 = bq4_1->d;
-    const float m = bq4_1->m;
-    const float d8_0 = bq8_0->d;
+    const float d = bq4_1->d * bq8_1->d;
+    const float m = bq4_1->m;
+    const float s = bq8_1->s;
 
     const int vi0 = (vi >> 0) & 0x0F0F0F0F;
     const int vi1 = (vi >> 4) & 0x0F0F0F0F;
 
-    const int sumi0 = __dp4a(vi0, ui0, 0);
-    const int sumi1 = __dp4a(vi1, ui1, 0);
-
-    const int sumi2 = __dp4a(0x01010101, ui0, 0);
-    const int sumi3 = __dp4a(0x01010101, ui1, 0);
+    int sumi = __dp4a(vi0, ui0, 0);
+    sumi = __dp4a(vi1, ui1, sumi);
 
-    return (sumi0 + sumi1)*d4_1*d8_0 + (sumi2 + sumi3)*m*d8_0;
+    return sumi*d + m*s / QI4_1;
 
 }
 
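Both dot products rely on 4-way byte SIMD within a 32-bit register: __vsub4(a, 0x08080808) subtracts 8 from each packed byte (shifting q4_0 nibbles from [0, 15] to the signed [-8, 7] range that q4_0 dequantization assumes), and __dp4a(a, b, c) returns c plus the dot product of the four signed 8-bit lanes of a and b. Each call handles sizeof(int)*2 = 8 of a block's 32 values, which is why the whole-block offset term m*sum(y) of q4_1 is split evenly as m*s / QI4_1 over the QI4_1 calls per block. Scalar equivalents of the two intrinsics, based on their documented semantics (illustration only, not part of the patch):

// Reference behaviour of the intrinsics used above.
static __device__ int dp4a_ref(int a, int b, int c) {
    const int8_t * va = (const int8_t *) &a;
    const int8_t * vb = (const int8_t *) &b;
    return c + va[0]*vb[0] + va[1]*vb[1] + va[2]*vb[2] + va[3]*vb[3];
}

static __device__ int vsub4_ref(int a, int b) {
    unsigned int r = 0;
    for (int j = 0; j < 4; ++j) {
        const unsigned char la = (a >> (8*j)) & 0xFF;
        const unsigned char lb = (b >> (8*j)) & 0xFF;
        r |= (unsigned int)((unsigned char)(la - lb)) << (8*j); // per-byte, no saturation
    }
    return (int) r;
}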
@@ -1263,8 +1273,6 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
         return;
     }
 
-    const int tid = threadIdx.x;
-
     const int blocks_per_row = ncols / qk;
     const int blocks_per_warp = WARP_SIZE * sizeof(int)*2 / qk;
     const int ints_per_block = qk / (2 * sizeof(int));
@@ -1273,14 +1281,14 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
     float tmp = 0.0f;
 
     const block_q_t * x = (const block_q_t *) vx;
-    const block_q8_0 * y = (const block_q8_0 *) vy;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
 
     for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i + tid/ints_per_block; // x block index
+        const int ibx = row*blocks_per_row + i + threadIdx.x/ints_per_block; // x block index
 
-        const int iby = i + tid/ints_per_block;
+        const int iby = i + threadIdx.x/ints_per_block;
 
-        const int iqs = tid % ints_per_block;
+        const int iqs = threadIdx.x % ints_per_block;
 
         tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
     }
@@ -1292,7 +1300,7 @@ static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * d
         tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
     }
 
-    if (tid == 0) {
+    if (threadIdx.x == 0) {
         dst[row] = tmp;
     }
 }
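For orientation, the indexing in mul_mat_vec_q works out as follows for the q4 types used here (qk = QK4_0 = 32, WARP_SIZE = 32): ints_per_block = 32 / (2*sizeof(int)) = 4, so four consecutive lanes of a warp share one quantized block, each handling one 32-bit chunk of nibbles (iqs = threadIdx.x % 4), and blocks_per_warp = 32*sizeof(int)*2 / 32 = 8 blocks are consumed per loop iteration. The closing __shfl_xor_sync loop then folds the 32 per-lane partial sums together so lane 0 can write dst[row]. The same arithmetic as compile-time checks (a hypothetical snippet, not in the patch):

// Worked numbers for the q4_0 / q4_1 instantiations (WARP_SIZE == 32).
static_assert(QK4_0 / (2*sizeof(int)) == 4, "4 lanes of a warp cover one q4_0 block");
static_assert(32 * sizeof(int)*2 / QK4_0 == 8, "one warp advances 8 blocks per loop iteration");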
@@ -1612,9 +1620,9 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
     rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
-static void quantize_row_q8_0_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
+static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    quantize_q8_0<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
+    quantize_q8_1<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
 }
 
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1770,21 +1778,21 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
     dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
-static void mul_mat_vec_q4_0_q8_0_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    mul_mat_vec_q<QK4_0, block_q4_0, vec_dot_q4_0_q8_0>
+    mul_mat_vec_q<QK4_0, block_q4_0, vec_dot_q4_0_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
-static void mul_mat_vec_q4_1_q8_0_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
-    mul_mat_vec_q<QK4_0, block_q4_1, vec_dot_q4_1_q8_0>
+    mul_mat_vec_q<QK4_0, block_q4_1, vec_dot_q4_1_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -2302,14 +2310,14 @@ inline void ggml_cuda_op_mul_mat_vec_q(
 
     size_t as;
     void * src1_q8_0 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_0)/QK8_0, &as);
-    quantize_row_q8_0_cuda(src1_ddf_i, src1_q8_0, ne00, cudaStream_main);
+    quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_0, ne00, cudaStream_main);
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_vec_q4_0_q8_0_cuda(src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
+            mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_vec_q4_1_q8_0_cuda(src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
+            mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_0, dst_ddf_i, ne00, nrows, cudaStream_main);
             break;
        default:
            GGML_ASSERT(false);