@@ -395,49 +395,63 @@ kernel void kernel_mul_mat_q4_0_f32(
395
395
// each thread in a SIMD group deals with 1 block.
396
396
for (int column = 0 ; column < nb / N_SIMDWIDTH; column++) {
397
397
398
+ float sumy = 0 ;
398
399
for (int i = 0 ; i < QK4_0 / 4 ; i++) {
399
400
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
401
+ sumy += y_curr[i][0 ] + y_curr[i][1 ] + y_curr[i][2 ] + y_curr[i][3 ];
400
402
}
403
+ sumy *= (-8 .f );
401
404
402
405
for (int row = 0 ; row < N_DST; row++) {
403
406
// prefetch next x block
404
407
qb_next = x[tiisg + ((row + 1 ) % N_DST) * nb + (column + ((row + 1 ) / N_DST)) * N_SIMDWIDTH];
405
408
406
409
// calculate
407
410
float d = qb_curr.d ;
408
- float2 acc = { 0 . 0f , 0 . 0f } ;
411
+ float acc = sumy ;
409
412
for (int i = 0 ; i < 16 ; i++) {
410
- acc[0 ] += yl[i] * (qb_curr.qs [i] & 0xF ) + yl[i+16 ] * (qb_curr.qs [i] >> 4 );
411
- acc[1 ] += yl[i] + yl[i+16 ];
413
+ acc += yl[i] * (qb_curr.qs [i] & 0xF ) + yl[i+16 ] * (qb_curr.qs [i] >> 4 );
412
414
}
413
- sumf[row] += d * ( acc[ 0 ] - 8 . f *acc[ 1 ]) ;
415
+ sumf[row] += d * acc;
414
416
qb_curr = qb_next;
415
417
}
416
418
}
417
419
418
- for (int i = 0 ; i < QK4_0 / 4 ; i++) {
419
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
420
- }
421
-
422
- for (int row = 0 ; row < N_DST; row++) {
423
- // prefetch next x block
424
- qb_next = x[tiisg + ((row + 1 ) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1 ) / N_DST)) * N_SIMDWIDTH];
425
-
426
- // calculate
427
- float d = qb_curr.d ;
428
- float2 acc = {0 .0f , 0 .0f };
429
- for (int i = 0 ; i < 16 ; i++) {
430
- acc[0 ] += yl[i] * (qb_curr.qs [i] & 0xF ) + yl[i+16 ] * (qb_curr.qs [i] >> 4 );
431
- acc[1 ] += yl[i] + yl[i+16 ];
420
+ if (nb % N_SIMDWIDTH == 0 ) {
421
+ for (int row = 0 ; row < N_DST; ++row) {
422
+ all_sum = simd_sum (sumf[row]);
423
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
424
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
425
+ }
432
426
}
433
- if (tiisg < nb % N_SIMDWIDTH) {
434
- sumf[row] += d * (acc[0 ] - 8 .f *acc[1 ]);
427
+ } else {
428
+
429
+ float sumy = 0 ;
430
+ for (int i = 0 ; i < QK4_0 / 4 ; i++) {
431
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
432
+ sumy += y_curr[i][0 ] + y_curr[i][1 ] + y_curr[i][2 ] + y_curr[i][3 ];
435
433
}
436
- qb_curr = qb_next;
434
+ sumy *= (-8 .f );
435
+
436
+ for (int row = 0 ; row < N_DST; row++) {
437
+ // prefetch next x block
438
+ qb_next = x[tiisg + ((row + 1 ) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1 ) / N_DST)) * N_SIMDWIDTH];
439
+
440
+ // calculate
441
+ float d = qb_curr.d ;
442
+ float acc = sumy;
443
+ for (int i = 0 ; i < 16 ; i++) {
444
+ acc += yl[i] * (qb_curr.qs [i] & 0xF ) + yl[i+16 ] * (qb_curr.qs [i] >> 4 );
445
+ }
446
+ if (tiisg < nb % N_SIMDWIDTH) {
447
+ sumf[row] += d * acc;
448
+ }
449
+ qb_curr = qb_next;
437
450
438
- all_sum = simd_sum (sumf[row]);
439
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
440
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
451
+ all_sum = simd_sum (sumf[row]);
452
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
453
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
454
+ }
441
455
}
442
456
}
443
457
}
@@ -449,65 +463,83 @@ kernel void kernel_mul_mat_q4_1_f32(
449
463
constant int64_t & ne00,
450
464
constant int64_t & ne10,
451
465
constant int64_t & ne0,
452
- threadgroup float * sum [[threadgroup( 0 )]],
466
+ constant int64_t & ne01[[buffer( 4 )]],
453
467
uint2 tgpig[[threadgroup_position_in_grid]],
454
- uint2 tpitg[[thread_position_in_threadgroup]],
455
- uint2 tptg[[threads_per_threadgroup]]) {
456
- const int nb = ne00/QK4_1;
457
-
458
- const int64_t r0 = tgpig.x ;
459
- const int64_t r1 = tgpig.y ;
460
-
461
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
468
+ uint tiisg[[thread_index_in_simdgroup]],
469
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
470
+ const int nb = ne00/QK4_0;
471
+ const int r0 = tgpig.x ;
472
+ const int r1 = tgpig.y ;
473
+ device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
462
474
device const float * y = (device const float *) src1 + r1*ne10;
475
+ block_q4_1 qb_curr, qb_next;
476
+ float4 y_curr[8 ]; // src1 vector cache
477
+ float sumf[N_DST]={0 .f }, all_sum;
478
+ thread float * yl=(thread float *)y_curr;
463
479
464
- const uint nth = tptg.x *tptg.y ;
465
- const uint ith = tptg.y *tpitg.x + tpitg.y ;
466
-
467
- const int ix = tpitg.y /4 ; // 0 or 1
468
- const int iy = tpitg.y - 4 *ix; // 0...3
469
-
470
- const int first = 4 * iy;
471
-
472
- float sumf = 0 ;
473
-
474
- for (int i = 2 *tpitg.x + ix; i < nb; i += 2 *tptg.x ) {
475
-
476
- const float d = (float )x[i].d ;
477
- const float m = (float )x[i].m ;
480
+ // bootstrap
481
+ qb_curr = x[tiisg];
482
+ // each thread in a SIMD group deals with 1 block.
483
+ for (int column = 0 ; column < nb / N_SIMDWIDTH; column++) {
478
484
479
- device const uint8_t * xl = x[i].qs + first;
480
- device const float * yl = y + i * QK4_1 + first;
485
+ float sumy = 0 ;
486
+ for (int i = 0 ; i < QK4_0 / 4 ; i++) {
487
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
488
+ sumy += y_curr[i][0 ] + y_curr[i][1 ] + y_curr[i][2 ] + y_curr[i][3 ];
489
+ }
481
490
482
- float2 acc = {0 .0f , 0 .0f };
491
+ for (int row = 0 ; row < N_DST; row++) {
492
+ // prefetch next x block
493
+ qb_next = x[tiisg + ((row + 1 ) % N_DST) * nb + (column + ((row + 1 ) / N_DST)) * N_SIMDWIDTH];
483
494
484
- for (int j = 0 ; j < 4 ; ++j) {
495
+ // calculate
496
+ const float d = qb_curr.d ;
497
+ const float m = qb_curr.m ;
498
+ float acc = 0 .f ;
499
+ for (int i = 0 ; i < 16 ; i++) {
500
+ acc += yl[i] * (qb_curr.qs [i] & 0xF ) + yl[i+16 ] * (qb_curr.qs [i] >> 4 );
501
+ }
502
+ sumf[row] += d * acc + m * sumy;
503
+ qb_curr = qb_next;
504
+ }
505
+ }
485
506
486
- acc[0 ] += yl[j+ 0 ] * (d * (xl[j] & 0xF ) + m);
487
- acc[1 ] += yl[j+16 ] * (d * (xl[j] >> 4 ) + m);
507
+ if (nb % N_SIMDWIDTH == 0 ) {
508
+ for (int row = 0 ; row < N_DST; ++row) {
509
+ all_sum = simd_sum (sumf[row]);
510
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
511
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
512
+ }
513
+ }
514
+ } else {
488
515
516
+ float sumy = 0 ;
517
+ for (int i = 0 ; i < QK4_0 / 4 ; i++) {
518
+ y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
519
+ sumy += y_curr[i][0 ] + y_curr[i][1 ] + y_curr[i][2 ] + y_curr[i][3 ];
489
520
}
490
521
491
- sumf += acc[0 ] + acc[1 ];
492
- }
522
+ for (int row = 0 ; row < N_DST; row++) {
523
+ // prefetch next x block
524
+ qb_next = x[tiisg + ((row + 1 ) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1 ) / N_DST)) * N_SIMDWIDTH];
493
525
494
- sum[ith] = sumf;
526
+ // calculate
527
+ const float d = qb_curr.d ;
528
+ const float m = qb_curr.m ;
529
+ float acc = 0 .f ;
530
+ for (int i = 0 ; i < 16 ; i++) {
531
+ acc += yl[i] * (qb_curr.qs [i] & 0xF ) + yl[i+16 ] * (qb_curr.qs [i] >> 4 );
532
+ }
533
+ if (tiisg < nb % N_SIMDWIDTH) {
534
+ sumf[row] += d * acc + m * sumy;
535
+ }
536
+ qb_curr = qb_next;
495
537
496
- //
497
- // Accumulate the sum from all threads in the threadgroup
498
- //
499
- threadgroup_barrier (mem_flags::mem_threadgroup);
500
- if (ith%4 == 0 ) {
501
- sum[ith] += sum[ith+1 ] + sum[ith+2 ] + sum[ith+3 ];
502
- }
503
- threadgroup_barrier (mem_flags::mem_threadgroup);
504
- if (ith%16 == 0 ) {
505
- sum[ith] += sum[ith+4 ] + sum[ith+8 ] + sum[ith+12 ];
506
- }
507
- threadgroup_barrier (mem_flags::mem_threadgroup);
508
- if (ith == 0 ) {
509
- for (uint i = 16 ; i < nth; i += 16 ) sum[0 ] += sum[i];
510
- dst[r1*ne0 + r0] = sum[0 ];
538
+ all_sum = simd_sum (sumf[row]);
539
+ if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
540
+ dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
541
+ }
542
+ }
511
543
}
512
544
}
513
545
0 commit comments