Skip to content

Commit a3e6d62

Browse files
committed
cuda : alternative q4_q8 kernel
1 parent e7b9d97 commit a3e6d62

File tree

1 file changed

+94
-3
lines changed

1 file changed

+94
-3
lines changed

ggml-cuda.cu

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,92 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
274274
}
275275
}
276276

277+
template <int NT, int NR> static __global__ void dequantize_mul_mat_q4_0_test(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
278+
const block_q4_0 * x = (const block_q4_0 *) vx;
279+
const block_q8_0 * y = (const block_q8_0 *) vy;
280+
281+
const int bid = blockIdx.x;
282+
const int tid = threadIdx.x;
283+
284+
__shared__ float tmp[NR][NT];
285+
for (int i = 0; i < NR; ++i) {
286+
tmp[i][tid] = 0.0f;
287+
}
288+
289+
const int nbc = (ncols + 16*NT - 1)/(16*NT);
290+
const int nbm = ncols/QK8_0;
291+
292+
uint64_t xa0;
293+
uint64_t xa1;
294+
295+
const int8_t * xb0 = (const int8_t *) &xa0;
296+
const int8_t * xb1 = (const int8_t *) &xa1;
297+
298+
for (int ibc = 0; ibc < nbc; ++ibc) {
299+
const int iyb = (ibc*(16*NT) + 16*tid)/QK8_0;
300+
const int iyq = (ibc*(16*NT) + 16*tid)%QK8_0;
301+
302+
if (iyb >= nbm) {
303+
continue;
304+
}
305+
306+
const int8_t * yb = (const int8_t *) &y[iyb].qs[iyq];
307+
308+
const float dy = y[iyb].d;
309+
310+
for (int ibr = 0; ibr < NR; ++ibr) {
311+
const int ir = bid*NR + ibr;
312+
if (ir >= nrows) {
313+
continue;
314+
}
315+
316+
// block offset
317+
const int ixo = (ir*ncols)/QK4_0 + iyb;
318+
319+
memcpy(&xa0, &x[ixo].qs[iyq/2 + 0], sizeof(uint64_t));
320+
xa1 = xa0;
321+
322+
xa0 = (xa0 ) & 0x0F0F0F0F0F0F0F0F;
323+
xa1 = (xa1 >> 4) & 0x0F0F0F0F0F0F0F0F;
324+
325+
const float dx = x[ixo].d;
326+
327+
// the (int) cast is probably unnecessary, but just to make sure the result is accumulated in 32 bits
328+
tmp[ibr][tid] += (
329+
((int)(xb0[0] - 8))*yb[0] + ((int)(xb1[0] - 8))*yb[1] +
330+
((int)(xb0[1] - 8))*yb[2] + ((int)(xb1[1] - 8))*yb[3] +
331+
((int)(xb0[2] - 8))*yb[4] + ((int)(xb1[2] - 8))*yb[5] +
332+
((int)(xb0[3] - 8))*yb[6] + ((int)(xb1[3] - 8))*yb[7] +
333+
((int)(xb0[4] - 8))*yb[8] + ((int)(xb1[4] - 8))*yb[9] +
334+
((int)(xb0[5] - 8))*yb[10] + ((int)(xb1[5] - 8))*yb[11] +
335+
((int)(xb0[6] - 8))*yb[12] + ((int)(xb1[6] - 8))*yb[13] +
336+
((int)(xb0[7] - 8))*yb[14] + ((int)(xb1[7] - 8))*yb[15]
337+
)*dx*dy;
338+
}
339+
}
340+
341+
// reduce
342+
__syncthreads();
343+
344+
for (int s = NT/2; s > 0; s >>= 1) {
345+
if (tid < s) {
346+
for (int ibr = 0; ibr < NR; ++ibr) {
347+
tmp[ibr][tid] += tmp[ibr][tid + s];
348+
}
349+
}
350+
__syncthreads();
351+
}
352+
353+
if (tid == 0) {
354+
for (int ibr = 0; ibr < NR; ++ibr) {
355+
const int ir = bid*NR + ibr;
356+
if (ir < nrows) {
357+
dst[ir] = tmp[ibr][0];
358+
}
359+
}
360+
}
361+
}
362+
277363
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
278364
const int nb = k / QK4_0;
279365
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -316,9 +402,14 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float
316402
// }
317403
// }
318404
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
319-
const int block_size = 32;
320-
GGML_ASSERT(ncols % block_size == 0);
321-
dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
405+
//const int block_size = 32;
406+
//GGML_ASSERT(ncols % block_size == 0);
407+
//dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
408+
409+
const int NR = 1; // unroll rows (seems to not help)
410+
const int NT = 64; // number of thrads per row
411+
412+
dequantize_mul_mat_q4_0_test<NT, NR><<<(nrows + NR - 1)/NR, NT, 0, stream>>>(vx, y, dst, ncols, nrows);
322413
}
323414

324415
// TODO: optimize

0 commit comments

Comments
 (0)