Commit bca8a68

mmq template, q4_1
1 parent 31f229c commit bca8a68

File tree

1 file changed: +139, -41 lines

ggml-cuda.cu

Lines changed: 139 additions & 41 deletions
@@ -87,8 +87,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
 #define QR4_1 2
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 typedef struct {
-    half d;                 // delta
-    half m;                 // min
+    half2 dm;               // dm.x = delta, dm.y = min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
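
Note on the layout change above: packing delta and min into one half2 lets a kernel fetch both scales with a single 32-bit load and multiply them against the q8_1 scale pair with one __hmul2 (see vec_dot_q4_1_q8_1_impl further down). A minimal device-side sketch; pack_dm, get_d and get_m are illustrative names, not ggml functions:

    // sketch only: producing and reading back the packed (delta, min) pair
    static __device__ half2 pack_dm(const float d, const float m) {
        return __floats2half2_rn(d, m);  // dm.x = delta, dm.y = min
    }
    static __device__ float get_d(const block_q4_1 * b) { return __half2float(b->dm.x); }
    static __device__ float get_m(const block_q4_1 * b) { return __half2float(b->dm.y); }

The struct size is unchanged (two fp16 values plus 16 bytes of nibbles), which the static_assert above verifies, so existing quantized files stay compatible.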
@@ -133,6 +132,13 @@ typedef struct {
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
 
 typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
+typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc);
+typedef void (*load_tiles_cuda_t)(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row);
+typedef float (*vec_dot_q_mul_mat_cuda_t)(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
 
 //================================= k-quants
 
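These three typedefs are the hooks of the templated mul_mat_q kernel below: allocate_tiles_cuda_t reserves shared-memory tiles, load_tiles_cuda_t fills them from global memory, and vec_dot_q_mul_mat_cuda_t computes one dot product from the tiles. (The last parameter is spelled y_ms here but y_ds at the implementations; parameter names in a function-pointer typedef are not binding.) A minimal sketch of the underlying pattern, a __device__ function bound as a non-type template parameter, with hypothetical names:

    typedef int (*op_cuda_t)(const int);                     // hypothetical hook type
    static __device__ int twice(const int v) { return 2*v; }
    template <op_cuda_t op>
    static __global__ void apply(int * x) { *x = op(*x); }   // resolved at compile time
    // instantiation, as the launchers below do: apply<twice><<<1, 1>>>(d_x);

Because the hook is a template argument rather than a runtime function pointer, the compiler can inline it into each instantiation.
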
@@ -380,8 +386,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
 static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const dfloat d = x[ib].d;
-    const dfloat m = x[ib].m;
+    const dfloat d = x[ib].dm.x;
+    const dfloat m = x[ib].dm.y;
 
     const int vui = x[ib].qs[iqs];
 
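For reference, q4_1 dequantization is v = d*q + m per 4-bit quant, with low nibbles filling the first half of the block and high nibbles the second. A host-side sketch mirroring the device function above (an illustration, not the ggml.c reference verbatim):

    // dequantize n values (n must be a multiple of QK4_1)
    static void dequantize_row_q4_1_ref(const block_q4_1 * x, float * y, const int n) {
        for (int i = 0; i < n/QK4_1; ++i) {
            const float d = __half2float(x[i].dm.x); // delta
            const float m = __half2float(x[i].dm.y); // min
            for (int j = 0; j < QK4_1/2; ++j) {
                const int vui = x[i].qs[j];
                y[i*QK4_1 + j]           = d*(vui & 0xF) + m; // low nibble
                y[i*QK4_1 + j + QK4_1/2] = d*(vui >>  4) + m; // high nibble
            }
        }
    }

For example, with d = 0.1 and m = -1.0, the byte 0x4B yields 0.1*11 - 1.0 = 0.1 from its low nibble and 0.1*4 - 1.0 = -0.6 from its high nibble.
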
@@ -1313,33 +1319,111 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, bq4_0->d, bq8_1->ds);
 }
 
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) {
 
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+    __shared__ int   tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)];
+    __shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI4_0)];
 
-    const int vi  = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
-    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
-    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
+    *x_ql = tile_x_qs;
+    *x_dm = tile_x_d;
+}
 
-    const float d = __half2float(bq4_1->d) * __half2float(bq8_1->ds.x);
-    const float m = bq4_1->m;
-    const float s = bq8_1->ds.y;
+static __device__ __forceinline__ void load_tiles_q4_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+    const int kbx  = k / QI4_0;
+    const int kqsx = sizeof(int) * (k % QI4_0);
+
+    const block_q4_0 * bx = ((block_q4_0 *) vx) + i*blocks_per_row + kbx;
+
+    memcpy(&x_ql[i * WARP_SIZE + i + k], &bx->qs[kqsx], sizeof(int));
+    x_dm[i * (WARP_SIZE / QI4_0) + kbx].x = bx->d;
+}
 
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+    return vec_dot_q4_0_q8_1_impl(
+        x_ql[i * WARP_SIZE + i + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
+        x_dm[i * (WARP_SIZE/QI4_0) + k/QI4_0].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+    const int & vi, const int & ui0, const int & ui1, const half2 & dm4, const half2 & ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const int vi0 = (vi >> 0) & 0x0F0F0F0F;
     const int vi1 = (vi >> 4) & 0x0F0F0F0F;
 
     // SIMD dot product of quantized values
     int sumi = __dp4a(vi0, ui0, 0);
     sumi     = __dp4a(vi1, ui1, sumi);
 
-    return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
+#ifdef GGML_CUDA_DMMV_F16
+    const half2 tmp = __hmul2(dm4, ds8);
+    const float d4d8 = __half2float(tmp.x);
+    const float m4s8 = __half2float(tmp.y);
+#else
+    const float d4d8 = __half2float(dm4.x) * __half2float(ds8.x);
+    const float m4s8 = __half2float(dm4.y) * __half2float(ds8.y);
+#endif // GGML_CUDA_DMMV_F16
+
+    // scale second part of sum by QI8_1/QR4_1 to compensate for multiple threads adding it
+    return sumi * d4d8 + m4s8 / (QI8_1 / QR4_1);
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+    const int vi  = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+    const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
+
+    return vec_dot_q4_1_q8_1_impl(vi, ui0, ui1, bq4_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int8_t ** x_sc) {
+
+    __shared__ int   tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)];
+    __shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_1)];
+
+    *x_ql = tile_x_qs;
+    *x_dm = tile_x_dm;
+}
+
+static __device__ __forceinline__ void load_tiles_q4_1(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int8_t * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
+
+    const int kbx  = k / QI4_1;
+    const int kqsx = sizeof(int) * (k % QI4_1);
+
+    const block_q4_1 * bx = ((block_q4_1 *) vx) + i*blocks_per_row + kbx;
+
+    x_ql[i * WARP_SIZE + i + k] = *((int *) &bx->qs[kqsx]);
+    x_dm[i * (WARP_SIZE / QI4_1) + kbx] = bx->dm;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int8_t * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+
+    return vec_dot_q4_1_q8_1_impl(
+        x_ql[i * WARP_SIZE + i + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
+        x_dm[i * (WARP_SIZE/QI4_1) + k/QI4_1], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+}
+
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
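The factoring behind vec_dot_q4_1_q8_1_impl: expanding the block dot product gives sum((d4*q4 + m4) * d8*q8) = d4*d8 * sum(q4*q8) + m4 * (d8 * sum(q8)). The first term is sumi*d4d8; the second is m4s8, because block_q8_1 stores s8 = d8*sum(q8) in ds.y at quantization time. QI8_1/QR4_1 threads each add the constant m4s8 term for the same block, hence the division. A host-side sanity check of the identity on a toy 4-value block (a sketch with made-up numbers):

    #include <stdio.h>
    int main(void) {
        const int   q4[4] = {3, 15, 0, 7};   // unsigned 4-bit quants
        const int   q8[4] = {-5, 12, 9, -1}; // signed 8-bit quants
        const float d4 = 0.1f, m4 = -0.8f, d8 = 0.05f;
        int sumi = 0, sum8 = 0;
        float exact = 0.0f;
        for (int i = 0; i < 4; ++i) {
            sumi  += q4[i]*q8[i];
            sum8  += q8[i];
            exact += (d4*q4[i] + m4) * (d8*q8[i]);
        }
        const float s8 = d8*sum8; // precomputed as block_q8_1.ds.y
        printf("exact %f vs factored %f\n", exact, sumi*d4*d8 + m4*s8);
        return 0;
    }

The GGML_CUDA_DMMV_F16 branch computes both products with a single __hmul2 on the packed (d4, m4) and (d8, s8) pairs, which is what the half2 layout change at the top of this diff enables.
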
@@ -1647,15 +1731,17 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
+template <int qk, int qi, typename block_q_t,
+          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __global__ void mul_mat_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst) {
 
-    const block_q4_0 * x = (const block_q4_0 *) vx;
+    const block_q_t  * x = (const block_q_t  *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    const int blocks_per_row = ncols_x / QK4_0;
-    const int blocks_per_warp = WARP_SIZE / QI4_0;
+    const int blocks_per_row = ncols_x / qk;
+    const int blocks_per_warp = WARP_SIZE / qi;
 
     const int & ncols_dst = ncols_y;
 
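With qk and qi as template arguments, blocks_per_row and blocks_per_warp become compile-time arithmetic in every instantiation, and the three hook calls resolve statically. Compile-time sanity checks for the q4_1 instantiation (a sketch, not part of the commit; constants as defined near the top of ggml-cuda.cu, with WARP_SIZE == 32):

    static_assert(QK4_1 == 32, "one q4_1 block holds 32 quants");
    static_assert(QI4_1 == 4,  "its nibbles span 4 ints");
    static_assert(WARP_SIZE / QI4_1 == 8, "q4_1 blocks advanced per warp iteration");
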
@@ -1669,20 +1755,23 @@ static __global__ void mul_mat_q(
     const int col_dst_0 = blockIdx.y*WARP_SIZE;
     const int & col_y_0 = col_dst_0;
 
-    __shared__ int   tile_x_qs[2*WARP_SIZE][WARP_SIZE + 1];
-    __shared__ half  tile_x_d[2*WARP_SIZE][WARP_SIZE/QI4_0];
-    __shared__ int   tile_y_qs[WARP_SIZE][2*WARP_SIZE];
-    __shared__ half2 tile_y_ds[WARP_SIZE][2*WARP_SIZE/QI8_1];
+    int    * tile_x_ql = nullptr;
+    half2  * tile_x_dm = nullptr;
+    int    * tile_x_qh = nullptr;
+    int8_t * tile_x_sc = nullptr;
+
+    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
+
+    __shared__ int   tile_y_qs[(WARP_SIZE) * (2*WARP_SIZE)];
+    __shared__ half2 tile_y_ds[(WARP_SIZE) * (2*WARP_SIZE/QI8_1)];
 
     float sum[2][4] = {0.0f};
 
     for (int ib0 = 0; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
-        const int ibx  = tid_x / QI4_0;
-        const int iqsx = sizeof(int) * (tid_x % QI4_0);
 
-        for (int j = 0; j < 2*WARP_SIZE; j += 8) {
-            const block_q4_0 * __restrict__ bx = &x[(row_x_0 + j + tid_y)*blocks_per_row + ib0 + ibx];
-            memcpy(&tile_x_qs[j + tid_y][tid_x], &bx->qs[iqsx], sizeof(int));
-            tile_x_d[j + tid_y][ibx] = bx->d;
+        for (int i = 0; i < 2*WARP_SIZE; i += 8) {
+            load_tiles(x + row_x_0*blocks_per_row + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+                       i + tid_y, tid_x, blocks_per_row);
         }
 
         const int iby0 = tid_x / QI8_1;
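
The flattened index the load_tiles hooks write, i * WARP_SIZE + i + k, equals i * (WARP_SIZE + 1) + k: it is the old 2D tile [2*WARP_SIZE][WARP_SIZE + 1] laid out by hand, and the extra pad column staggers rows across shared-memory banks to avoid bank conflicts. A trivial host check of the equivalence (a sketch, assuming WARP_SIZE == 32):

    #include <assert.h>
    #define WARP_SIZE 32
    int main(void) {
        for (int i = 0; i < 2*WARP_SIZE; ++i)
            for (int k = 0; k < WARP_SIZE; ++k)
                assert(i*WARP_SIZE + i + k == i*(WARP_SIZE + 1) + k);
        return 0;
    }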
@@ -1692,26 +1781,23 @@ static __global__ void mul_mat_q(
         for (int i = 0; i < WARP_SIZE; i += 8) {
             const block_q8_1 * __restrict__ by0 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby0];
 
-            tile_y_qs[tid_y + i][tid_x] = *((int *) &by0->qs[iqsy]);
-            tile_y_ds[tid_y + i][iby0] = by0->ds;
+            tile_y_qs[(tid_y + i) * (2*WARP_SIZE) + tid_x] = *((int *) &by0->qs[iqsy]);
+            tile_y_ds[(tid_y + i) * (2*WARP_SIZE/QI8_1) + iby0] = by0->ds;
 
             const block_q8_1 * __restrict__ by1 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby1];
 
-            tile_y_qs[tid_y + i][tid_x + WARP_SIZE] = *((int *) &by1->qs[iqsy]);
-            tile_y_ds[tid_y + i][iby1] = by1->ds;
+            tile_y_qs[(tid_y + i) * (2*WARP_SIZE) + tid_x + WARP_SIZE] = *((int *) &by1->qs[iqsy]);
+            tile_y_ds[(tid_y + i) * (2*WARP_SIZE/QI8_1) + iby1] = by1->ds;
         }
 
         __syncthreads();
 
         for (int k = 0; k < WARP_SIZE; ++k) {
-            const int iqsy = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
             for (int j = 0; j < WARP_SIZE; j += 8) {
-                sum[0][j/8] += vec_dot_q4_0_q8_1_impl(
-                    tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
-                    tile_x_d[tid_x][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
-                sum[1][j/8] += vec_dot_q4_0_q8_1_impl(
-                    tile_x_qs[tid_x + WARP_SIZE][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
-                    tile_x_d[tid_x + WARP_SIZE][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
+                sum[0][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
+                                       tid_x, tid_y + j, k);
+                sum[1][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
+                                       tid_x + WARP_SIZE, tid_y + j, k);
             }
         }
 
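The shuffle kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)), which moved from this loop into the *_mul_mat functions, pairs each q4 int with the two q8 ints covering its low and high nibbles: a q4 block spans QI8_1/2 ints of quants while a q8_1 block spans QI8_1, so the high-nibble partner sits QI8_1/2 ints further on. A small dump of the mapping (host sketch; QI8_1 == 8 as defined in this file):

    #include <stdio.h>
    #define QI8_1 8
    int main(void) {
        for (int k = 0; k < 8; ++k) {
            const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
            printf("k=%d -> y ints %d and %d\n", k, kyqs, kyqs + QI8_1/2);
        }
        return 0; // k=0..3 -> (0,4)..(3,7), k=4..7 -> (8,12)..(11,15)
    }
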
@@ -2425,7 +2511,15 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(const void * vx, const void * vy, float
     const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_nums(block_num_x, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-    mul_mat_q<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+    mul_mat_q<QK4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0, vec_dot_q4_0_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
+}
+
+static void ggml_mul_mat_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_dst, cudaStream_t stream){
+    const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
+    const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
+    const dim3 block_nums(block_num_x, block_num_y, 1);
+    const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
+    mul_mat_q<QK4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1, vec_dot_q4_1_q8_1_mul_mat><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_dst);
 }
 
 static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
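
Launch geometry for both wrappers above: block_dims is (WARP_SIZE, WARP_SIZE/4) = (32, 8), so each block runs 256 threads and, with the sum[2][4] accumulators in mul_mat_q, produces one (2*WARP_SIZE) x WARP_SIZE = 64 x 32 tile of dst; the grid is the matching pair of ceiling divisions. A worked example with made-up sizes (a sketch, assuming WARP_SIZE == 32):

    #include <stdio.h>
    #define WARP_SIZE 32
    int main(void) {
        const int nrows_x = 4096, ncols_y = 512;                    // example sizes
        const int bx = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE); // 64
        const int by = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;       // 16
        // 256 threads x 8 accumulators = 2048 = 64*32 dst elements per block
        printf("grid %d x %d, %d elements per tile\n", bx, by, 2*WARP_SIZE*WARP_SIZE);
        return 0;
    }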
@@ -2890,6 +2984,9 @@ inline void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q4_0:
             ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
             break;
+        case GGML_TYPE_Q4_1:
+            ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, nrows_dst, cudaStream_main);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -3639,7 +3736,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
         ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
     } else {
-        if (src0->type == GGML_TYPE_Q4_0) {
+        // if (src0->type == GGML_TYPE_Q4_0) {
+        if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1) {
            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
         } else {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
