@@ -365,71 +365,80 @@ kernel void kernel_rms_norm(
365
365
}
366
366
}
367
367
368
+ // these constants are defined at file scope: putting them in the kernel causes a significant performance penalty
369
+ #define N_DST 4 // each SIMD group works on 4 rows
370
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
371
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
368
372
// Matrix-vector product of Q4_0-quantized rows (src0) with an f32 vector (src1).
//
// Work split (as established by the indexing below):
//   - tgpig.y selects the src1 row `r1`; the result is written to dst[r1*ne0 + row].
//   - tgpig.x together with sgitg selects a group of N_DST consecutive src0 rows,
//     starting at row (tgpig.x*N_SIMDGROUP + sgitg)*N_DST; one SIMD group handles one
//     such group of rows.
//   - Within a SIMD group, lane `tiisg` processes block column*N_SIMDWIDTH + tiisg of
//     every row it owns, for column = 0, 1, ...; the per-lane partial sums are then
//     combined with simd_sum() and lane 0 stores the result.
//
// Parameters:
//   src0  - quantized matrix, read as block_q4_0[ne01][ne00/QK4_0]
//   src1  - f32 input; row r1 starts at src1 + r1*ne10
//   dst   - f32 output; row r1 starts at dst + r1*ne0
//   ne00  - number of elements per src0 row (nb = ne00/QK4_0 blocks)
//   ne01  - number of src0 rows; rows >= ne01 are neither read nor written
//
// NOTE(review): the float4 loads assume every src1 row start is 16-byte aligned and
// that the host dispatches N_SIMDGROUP SIMD groups of N_SIMDWIDTH lanes per
// threadgroup - confirm against the host-side encoder.
kernel void kernel_mul_mat_q4_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne10,
        constant   int64_t & ne0,
        constant   int64_t & ne01[[buffer(4)]],
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    const int nb = ne00/QK4_0;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;

    // first src0 row owned by this SIMD group (64-bit so that first_row*nb cannot wrap)
    const int64_t first_row = (int64_t)(r0 * N_SIMDGROUP + sgitg) * N_DST;

    // how many of the N_DST rows actually exist in the matrix (0..N_DST)
    const int64_t rows_left = ne01 - first_row;
    const int     n_rows    = rows_left >= N_DST ? N_DST : (rows_left > 0 ? (int) rows_left : 0);

    // number of block columns, including a final partial one
    const int n_cols = (nb + N_SIMDWIDTH - 1) / N_SIMDWIDTH;

    device const block_q4_0 * x = (device const block_q4_0 *) src0 + first_row * nb;
    device const float      * y = (device const float      *) src1 + r1*ne10;

    // zero-initialised so that the register copy below never moves an indeterminate value
    block_q4_0 qb_curr = {};
    block_q4_0 qb_next = {};

    float4 y_curr[QK4_0 / 4]; // src1 vector cache for the block this lane is working on
    float  sumf[N_DST] = {0.f}, all_sum;
    thread float * yl = (thread float *) y_curr;

    // bootstrap: block 0 of row 0 for this lane, only if it exists
    if ((int) tiisg < nb && n_rows > 0) {
        qb_curr = x[tiisg];
    }

    // each thread in a SIMD group deals with 1 block per column
    for (int column = 0; column < n_cols; column++) {
        const int  ib        = column * N_SIMDWIDTH + tiisg; // block index within a row
        const bool has_block = ib < nb;                      // false only in a partial last column

        if (has_block) {
            // block ib covers src1 elements [QK4_0*ib, QK4_0*(ib + 1))
            for (int i = 0; i < QK4_0 / 4; i++) {
                y_curr[i] = *((device const float4 *)(y + QK4_0 * ib + 4 * i));
            }
        }

        for (int row = 0; row < N_DST; row++) {
            // prefetch the block needed by the next step: the same column of the next row,
            // or, after the last row, row 0 of the next column
            const int next_row = (row + 1) % N_DST;
            const int next_ib  = ib + ((row + 1) / N_DST) * N_SIMDWIDTH;
            if (next_ib < nb && next_row < n_rows) {
                qb_next = x[next_row * nb + next_ib];
            }

            // calculate: this guard is the same predicate that guarded the load of qb_curr,
            // so qb_curr is only consumed when it holds x[row*nb + ib]
            if (has_block && row < n_rows) {
                const float d = qb_curr.d;
                float2 acc = {0.0f, 0.0f};
                for (int i = 0; i < 16; i++) {
                    // low nibble pairs with yl[i], high nibble with yl[i+16]
                    acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
                    acc[1] += yl[i] + yl[i+16];
                }
                // sum(y*q) - 8*sum(y), scaled by d, i.e. d * sum(y*(q - 8))
                sumf[row] += d * (acc[0] - 8.f*acc[1]);
            }
            qb_curr = qb_next;
        }
    }

    // reduce across the SIMD group; simd_sum is executed by every lane, only the store is guarded
    for (int row = 0; row < N_DST; row++) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0 && row < n_rows) {
            dst[r1*ne0 + first_row + row] = all_sum;
        }
    }
}
435
444
0 commit comments