Commit 5528a5f

q3_K works

1 parent 1861a36 commit 5528a5f

1 file changed

ggml-cuda.cu

Lines changed: 78 additions & 7 deletions
@@ -153,6 +153,8 @@ typedef struct {
 } block_q2_K;
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
+#define QR3_K 4
+#define QI3_K (QK_K / (4*QR3_K))
 typedef struct {
     uint8_t hmask[QK_K/8]; // quants - high bit
     uint8_t qs[QK_K/4]; // quants - low 2 bits
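
For orientation, not part of the diff itself: following the QR*/QI* convention of the existing quant types (QR = quant values covered per byte of qs, QI = number of 32-bit integers holding one block's low-bit quants, as I read the neighboring macros), the new defines work out as follows with the default QK_K = 256:

// A sketch of the macro arithmetic, assuming the default QK_K = 256:
// QR3_K = 4  -> each byte of qs packs four 2-bit quant values
// QI3_K = QK_K / (4*QR3_K) = 256 / 16 = 16
//            -> qs[QK_K/4] is 64 bytes, i.e. 16 ints, one per iqs step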
@@ -1259,7 +1261,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
     y[iybs + iqs + y_offset] = v.y;
 }
 
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
     const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
 
@@ -1284,7 +1287,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restric
 #endif // __CUDA_ARCH__ >= 610
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
     const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
 
@@ -1309,7 +1313,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restric
 #endif // __CUDA_ARCH__ >= 610
 }
 
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
     const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
 
@@ -1344,7 +1349,8 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restric
 #endif // __CUDA_ARCH__ >= 610
 }
 
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
     const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
 
@@ -1378,7 +1384,8 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restric
 #endif // __CUDA_ARCH__ >= 610
 }
 
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
 #if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
     const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
 
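The five hunks above only re-wrap the existing vec_dot_*_q8_1 signatures onto two lines to match the style of the k-quant kernels; no behavior changes. For readers unfamiliar with the intrinsics these kernels (and the new q3_K one below) are built on, a short reference; both are documented CUDA SIMD intrinsics:

// __dp4a(a, b, c): treats a and b as four packed signed 8-bit values and
//                  returns c + a0*b0 + a1*b1 + a2*b2 + a3*b3; hardware
//                  support begins at compute capability 6.1, which is what
//                  the __CUDA_ARCH__ >= 610 guards check for.
// __vsubss4(a, b): per-byte signed subtraction with saturation.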
@@ -1432,6 +1439,58 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 #endif // __CUDA_ARCH__ >= 610
 }
 
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+
+    const int bq8_offset = 4 * (iqs / (QI3_K/2));
+
+    float sumf = 0.0f;
+
+    const float d = bq3_K->d;
+
+    int vil;
+    memcpy(&vil, &bq3_K->qs[sizeof(int) * iqs], sizeof(int));
+
+    int vih;
+    memcpy(&vih, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int));
+    vih = ~vih;
+    vih >>= bq8_offset;
+
+    for (int i = 0; i < 4; ++i) {
+        const int isc = iqs - iqs%8 + (iqs%8) / 4 + 2*i;
+
+        const int isc_low = isc % (QK_K/32);
+        const int sc_shift_low = 4 * (isc / (QK_K/32));
+        const int sc_low = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF;
+
+        const int isc_high = isc % (QK_K/64);
+        const int sc_shift_high = 2 * (isc / (QK_K/64));
+        const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+        const int sc = (sc_low | sc_high) - 32;
+
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        const int uii = *((int*) &bq8i->qs[sizeof(int) * (iqs%8)]);
+        const float d8i = bq8i->d;
+
+        const int viil = (vil >> (2*i)) & 0x03030303;
+
+        const int viih = ((vih >> i) << 2) & 0x04040404;
+
+        const int vii = __vsubss4(viil, viih);
+
+        sumf += d8i * (__dp4a(vii, uii, 0) * sc);
+    }
+
+    return d*sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 610
+}
+
 template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
 static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
     const int row = blockIdx.y*blockDim.y + threadIdx.y;
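
To unpack the bit manipulation in vec_dot_q3_K_q8_1: each quant is 3 bits, split into 2 low bits in qs and 1 high bit in hmask; hmask is inverted up front (vih = ~vih) so that the __vsubss4 step can subtract 4 from every value whose original high bit was 0, four bytes at a time. A scalar sketch of the same reconstruction for a single value (my own illustration; dequant_one_q3 is a hypothetical helper, not code from the commit):

static int dequant_one_q3(const int low2, const int high_bit) {
    // mirrors ggml's CPU-side q3_K dequantization: an unset high bit
    // shifts the value down by 4, giving the signed 3-bit range [-4, 3]
    return low2 - (high_bit ? 0 : 4);
}

The sc computed in the loop is the matching 6-bit group scale, reassembled from 4 low bits and 2 high bits spread across scales[] and re-centered by the - 32 offset.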
@@ -1999,6 +2058,15 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, vec_dot_q3_K_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
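
A worked example of the launch geometry, under the assumption that WARP_SIZE is 32 and GGML_CUDA_MMV_Y is left at its default of 1 (both are build-time constants, so actual values may differ):

// nrows = 4096, GGML_CUDA_MMV_Y = 1 (assumed defaults):
//   block_num_y = (4096 + 1 - 1) / 1 = 4096   // one block per row
//   block_dims  = (32, 1, 1)                  // one warp per block
// Each warp walks one row of the q3_K matrix and reduces the partial dot
// products inside mul_mat_vec_q; the round-up in block_num_y only matters
// when GGML_CUDA_MMV_Y > 1 and nrows is not a multiple of it.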
@@ -2462,8 +2530,8 @@ inline void ggml_cuda_op_mul_mat_vec(
         src0->type == GGML_TYPE_Q5_0 ||
         src0->type == GGML_TYPE_Q5_1 ||
         src0->type == GGML_TYPE_Q8_0 ||
-        src0->type == GGML_TYPE_Q2_K;
-        // src0->type == GGML_TYPE_Q3_K ||
+        src0->type == GGML_TYPE_Q2_K ||
+        src0->type == GGML_TYPE_Q3_K;
         // src0->type == GGML_TYPE_Q4_K ||
         // src0->type == GGML_TYPE_Q5_K ||
         // src0->type == GGML_TYPE_Q6_K;
@@ -2497,6 +2565,9 @@ inline void ggml_cuda_op_mul_mat_vec(
             case GGML_TYPE_Q2_K:
                 mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
                 break;
+            case GGML_TYPE_Q3_K:
+                mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+                break;
             default:
                 GGML_ASSERT(false);
                 break;
