@@ -274,6 +274,92 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
     }
 }
 
+template <int NT, int NR> static __global__ void dequantize_mul_mat_q4_0_test(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+    const block_q8_0 * y = (const block_q8_0 *) vy;
+
+    const int bid = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    __shared__ float tmp[NR][NT];
+    for (int i = 0; i < NR; ++i) {
+        tmp[i][tid] = 0.0f;
+    }
+
+    const int nbc = (ncols + 16*NT - 1)/(16*NT);
+    const int nbm = ncols/QK8_0;
+
+    uint64_t xa0;
+    uint64_t xa1;
+
+    const int8_t * xb0 = (const int8_t *) &xa0;
+    const int8_t * xb1 = (const int8_t *) &xa1;
+
+    for (int ibc = 0; ibc < nbc; ++ibc) {
+        const int iyb = (ibc*(16*NT) + 16*tid)/QK8_0;
+        const int iyq = (ibc*(16*NT) + 16*tid)%QK8_0;
+
+        if (iyb >= nbm) {
+            continue;
+        }
+
+        const int8_t * yb = (const int8_t *) &y[iyb].qs[iyq];
+
+        const float dy = y[iyb].d;
+
+        for (int ibr = 0; ibr < NR; ++ibr) {
+            const int ir = bid*NR + ibr;
+            if (ir >= nrows) {
+                continue;
+            }
+
+            // block offset
+            const int ixo = (ir*ncols)/QK4_0 + iyb;
+
+            memcpy(&xa0, &x[ixo].qs[iyq/2 + 0], sizeof(uint64_t));
+            xa1 = xa0;
+
+            xa0 = (xa0     ) & 0x0F0F0F0F0F0F0F0F;
+            xa1 = (xa1 >> 4) & 0x0F0F0F0F0F0F0F0F;
+
+            const float dx = x[ixo].d;
+
+            // the (int) cast is probably unnecessary, but just to make sure the result is accumulated in 32 bits
+            tmp[ibr][tid] += (
+                ((int)(xb0[0] - 8))*yb[ 0] + ((int)(xb1[0] - 8))*yb[ 1] +
+                ((int)(xb0[1] - 8))*yb[ 2] + ((int)(xb1[1] - 8))*yb[ 3] +
+                ((int)(xb0[2] - 8))*yb[ 4] + ((int)(xb1[2] - 8))*yb[ 5] +
+                ((int)(xb0[3] - 8))*yb[ 6] + ((int)(xb1[3] - 8))*yb[ 7] +
+                ((int)(xb0[4] - 8))*yb[ 8] + ((int)(xb1[4] - 8))*yb[ 9] +
+                ((int)(xb0[5] - 8))*yb[10] + ((int)(xb1[5] - 8))*yb[11] +
+                ((int)(xb0[6] - 8))*yb[12] + ((int)(xb1[6] - 8))*yb[13] +
+                ((int)(xb0[7] - 8))*yb[14] + ((int)(xb1[7] - 8))*yb[15]
+            )*dx*dy;
+        }
+    }
+
+    // reduce
+    __syncthreads();
+
+    for (int s = NT/2; s > 0; s >>= 1) {
+        if (tid < s) {
+            for (int ibr = 0; ibr < NR; ++ibr) {
+                tmp[ibr][tid] += tmp[ibr][tid + s];
+            }
+        }
+        __syncthreads();
+    }
+
+    if (tid == 0) {
+        for (int ibr = 0; ibr < NR; ++ibr) {
+            const int ir = bid*NR + ibr;
+            if (ir < nrows) {
+                dst[ir] = tmp[ibr][0];
+            }
+        }
+    }
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK4_0;
     dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
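For readers following the kernel's inner loop: each iteration handles one 16-value chunk by loading 8 packed q4_0 bytes, splitting them into low and high nibbles (each offset by 8), pairing them with 16 consecutive q8_0 values, and scaling the integer sum by both block scales. The sketch below restates that arithmetic as plain host-side C++ for clarity; the q4_block/q8_block layouts and the chunk_dot helper are simplified stand-ins invented here for illustration, not part of the patch.

// Hedged illustration (not from the patch): simplified stand-ins for ggml's
// block_q4_0 / block_q8_0 layouts of the time (fp32 scale + packed quants).
#include <cstdint>
#include <cstring>
#include <cstdio>

struct q4_block { float d; uint8_t qs[16]; }; // 32 4-bit weights in 16 bytes
struct q8_block { float d; int8_t  qs[32]; }; // 32 8-bit activations

// Dot product over one 16-value chunk, mirroring the kernel's inner loop body;
// iyq is the starting element within the 32-element block (0 or 16).
static float chunk_dot(const q4_block & x, const q8_block & y, int iyq) {
    uint64_t lo;
    std::memcpy(&lo, &x.qs[iyq/2], sizeof(uint64_t)); // 8 bytes = 16 nibbles
    uint64_t hi = (lo >> 4) & 0x0F0F0F0F0F0F0F0Full;  // high nibbles
    lo          =  lo       & 0x0F0F0F0F0F0F0F0Full;  // low nibbles

    const int8_t * xb0 = (const int8_t *) &lo;
    const int8_t * xb1 = (const int8_t *) &hi;
    const int8_t * yb  = &y.qs[iyq];

    int sum = 0;
    for (int j = 0; j < 8; ++j) {
        sum += (xb0[j] - 8)*yb[2*j + 0]; // low  nibble pairs with even y index
        sum += (xb1[j] - 8)*yb[2*j + 1]; // high nibble pairs with odd  y index
    }
    return sum*x.d*y.d;
}

int main() {
    q4_block x = { 0.1f, {} };
    q8_block y = { 0.2f, {} };
    for (int i = 0; i < 16; ++i) x.qs[i] = 0x99; // both nibbles decode to +1
    for (int i = 0; i < 32; ++i) y.qs[i] = 1;
    // expect 32 * 0.1 * 0.2 = 0.64
    printf("%f\n", chunk_dot(x, y, 0) + chunk_dot(x, y, 16));
    return 0;
}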
@@ -316,9 +402,14 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float
     //     }
     // }
     // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
-    const int block_size = 32;
-    GGML_ASSERT(ncols % block_size == 0);
-    dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+    // const int block_size = 32;
+    // GGML_ASSERT(ncols % block_size == 0);
+    // dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+
+    const int NR = 1;  // unroll rows (does not seem to help)
+    const int NT = 64; // number of threads per row
+
+    dequantize_mul_mat_q4_0_test<NT, NR><<<(nrows + NR - 1)/NR, NT, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
 // TODO: optimize
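A note on the launch parameters above: each block of NT threads accumulates NR output rows, so the grid size is the ceiling division (nrows + NR - 1)/NR. As an illustrative example (numbers assumed, not taken from the patch), with NR = 1 and NT = 64 a 4096-row weight matrix launches 4096 blocks of 64 threads, and each block finishes with a shared-memory tree reduction of log2(64) = 6 __syncthreads() steps.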