@@ -384,27 +384,44 @@ kernel void kernel_rms_norm(
}
}

+ // function for calculating the inner product between a q4_0 block and 32 floats (yl); sumy is SUM(yl[i])
+ float block_q_n_dot_y(block_q4_0 qb_curr, float sumy, thread float * yl) {
+     float d = qb_curr.d;
+     float acc = sumy * -8.f;
+     for (int i = 0; i < 16; i += 2) {
+         acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+         acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
+     }
+     return d * acc;
+ }
+
+ // function for calculating the inner product between a q4_1 block and 32 floats (yl); sumy is SUM(yl[i])
+ float block_q_n_dot_y(block_q4_1 qb_curr, float sumy, thread float * yl) {
+     float d = qb_curr.d;
+     float m = qb_curr.m;
+     float acc = 0.f;
+     for (int i = 0; i < 16; i += 2) {
+         acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
+         acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
+     }
+     return d * acc + m * sumy;
+ }
+
// putting them in the kernel causes a significant performance penalty
#define N_DST 4        // each SIMD group works on 4 rows
#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
- kernel void kernel_mul_mat_q4_0_f32(
-         device const void * src0,
-         device const float * src1,
-         device float * dst,
-         constant int64_t & ne00,
-         constant int64_t & ne10,
-         constant int64_t & ne0,
-         constant int64_t & ne01[[buffer(4)]],
-         uint2 tgpig[[threadgroup_position_in_grid]],
-         uint tiisg[[thread_index_in_simdgroup]],
-         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ template <typename block_q_type>
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
+                      int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
+                      uint2 tgpig, uint tiisg, uint sgitg) {
    const int nb = ne00/QK4_0;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
-     device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+     device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
    device const float * y = (device const float *) src1 + r1*ne10;
-     block_q4_0 qb_curr, qb_next;
+     block_q_type qb_curr;
    float4 y_curr[8]; // src1 vector cache
    float sumf[N_DST]={0.f}, all_sum;
    thread float * yl=(thread float *)y_curr;
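The two block_q_n_dot_y() overloads added above rest on a simple split of the dot product: a q4_0 weight dequantizes as d * (q - 8), so dot(x, y) = d * (SUM(q*y) - 8 * SUM(y)), which is why that overload starts its accumulator at sumy * -8.f, while a q4_1 weight is d * q + m, giving d * SUM(q*y) + m * SUM(y). The standalone C++ sketch below is a hypothetical illustration, not part of the diff; it works on unpacked quants rather than masked packed nibbles and simply checks the identity against a naive dequantize-then-dot reference.

// Hypothetical host-side check of the identity used by block_q_n_dot_y():
//   q4_0: x[i] = d * (q[i] - 8)  =>  dot(x, y) = d * (SUM(q[i]*y[i]) - 8 * SUM(y[i]))
//   q4_1: x[i] = d * q[i] + m    =>  dot(x, y) = d * SUM(q[i]*y[i]) + m * SUM(y[i])
#include <cstdio>
#include <cstdlib>

int main() {
    const int QK = 32;              // elements per block, as in QK4_0/QK4_1
    const float d = 0.25f, m = 0.5f;
    int   q[QK];                    // unpacked 4-bit quants, 0..15
    float y[QK];
    for (int i = 0; i < QK; i++) {
        q[i] = rand() % 16;
        y[i] = (rand() % 200 - 100) / 100.f;
    }

    // naive reference: dequantize each weight, then take the dot product
    float ref0 = 0.f, ref1 = 0.f;
    for (int i = 0; i < QK; i++) {
        ref0 += (d * (q[i] - 8)) * y[i];
        ref1 += (d * q[i] + m)   * y[i];
    }

    // block_q_n_dot_y style: accumulate SUM(q*y) and SUM(y) separately
    float acc = 0.f, sumy = 0.f;
    for (int i = 0; i < QK; i++) {
        acc  += q[i] * y[i];
        sumy += y[i];
    }

    printf("q4_0: %.6f == %.6f\n", ref0, d * (acc + sumy * -8.f));
    printf("q4_1: %.6f == %.6f\n", ref1, d * acc + m * sumy);
    return 0;
}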
@@ -419,25 +436,15 @@ kernel void kernel_mul_mat_q4_0_f32(
            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
        }
-         sumy *= (-8.f);
        // we don't right shift packed 4-bit weights, so we have to divide y by 16/256/4096 to compensate for this.
+         // this design is centered on q4_0 and q4_1, but I think most people use these two quantizations.
        for (int i = 0; i < 32; i++) {
            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
        }

        for (int row = 0; row < N_DST; row++) {
-             // prefetch next x block
-             qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-             // calculate
-             float d = qb_curr.d;
-             float acc = sumy;
-             for (int i = 0; i < 16; i += 2) {
-                 acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-                 acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-             }
-             sumf[row] += d * acc;
-             qb_curr = qb_next;
+             sumf[row] += block_q_n_dot_y(qb_curr, sumy, yl);
+             qb_curr = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
        }
    }

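The yl scaling in the hunk above is what the comment about dividing y by 16/256/4096 refers to: qs is read as packed 16-bit groups of four quants, and masking with 0x000F/0x00F0/0x0F00/0xF000 without shifting leaves each nibble scaled by 1, 16, 256 or 4096, so the matching y element is pre-multiplied by pow(1/16, 2*(i%2) + i/16) to cancel exactly that factor. A small hypothetical C++ sketch of the trick (toy values, not part of the diff):

// Hypothetical illustration of the "divide y by 16/256/4096" trick: masking a packed
// 16-bit group of four 4-bit quants leaves each nibble scaled by 1, 16, 256 or 4096,
// so pre-scaling the matching y element by the inverse factor gives the same product
// as shifting the nibble down would.
#include <cstdio>
#include <cstdint>

int main() {
    const uint16_t qs = 0xB7C4;   // four packed quants: 0x4, 0xC, 0x7, 0xB
    const float y = 1.5f;         // one (unscaled) element of yl

    // shifted nibble times y  vs.  unshifted masked value times pre-scaled y
    printf("%g %g\n", (float)((qs >>  0) & 0xF) * y, (float)(qs & 0x000F) * (y / 1.f));
    printf("%g %g\n", (float)((qs >>  4) & 0xF) * y, (float)(qs & 0x00F0) * (y / 16.f));
    printf("%g %g\n", (float)((qs >>  8) & 0xF) * y, (float)(qs & 0x0F00) * (y / 256.f));
    printf("%g %g\n", (float)((qs >> 12) & 0xF) * y, (float)(qs & 0xF000) * (y / 4096.f));
    return 0;
}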
@@ -449,32 +456,20 @@ kernel void kernel_mul_mat_q4_0_f32(
            }
        }
    } else {
-
        float sumy = 0;
        for (int i = 0; i < QK4_0 / 4; i++) {
            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
        }
-         sumy *= (-8.f);
        for (int i = 0; i < 32; i++) {
            yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
        }

        for (int row = 0; row < N_DST; row++) {
-             // prefetch next x block
-             qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-             // calculate
-             float d = qb_curr.d;
-             float acc = sumy;
-             for (int i = 0; i < 16; i += 2) {
-                 acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-                 acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-             }
            if (tiisg < nb % N_SIMDWIDTH) {
-                 sumf[row] += d * acc;
+                 sumf[row] += block_q_n_dot_y(qb_curr, sumy, yl);
            }
-             qb_curr = qb_next;
+             qb_curr = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];

            all_sum = simd_sum(sumf[row]);
            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
@@ -484,7 +479,7 @@ kernel void kernel_mul_mat_q4_0_f32(
    }
}

- kernel void kernel_mul_mat_q4_1_f32(
+ kernel void kernel_mul_mat_q4_0_f32(
        device const void * src0,
        device const float * src1,
        device float * dst,
@@ -495,89 +490,21 @@ kernel void kernel_mul_mat_q4_1_f32(
        uint2 tgpig[[threadgroup_position_in_grid]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     const int nb = ne00/QK4_0;
-     const int r0 = tgpig.x;
-     const int r1 = tgpig.y;
-     device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
-     device const float * y = (device const float *) src1 + r1*ne10;
-     block_q4_1 qb_curr, qb_next;
-     float4 y_curr[8]; // src1 vector cache
-     float sumf[N_DST]={0.f}, all_sum;
-     thread float * yl=(thread float *)y_curr;
-
-     // bootstrap
-     qb_curr = x[tiisg];
-     // each thread in a SIMD group deals with 1 block.
-     for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
-         float sumy = 0;
-         for (int i = 0; i < QK4_0 / 4; i++) {
-             y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
-             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-         }
-         // we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
-         for (int i = 0; i < 32; i++) {
-             yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
-         }
-
-         for (int row = 0; row < N_DST; row++) {
-             // prefetch next x block
-             qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-             // calculate
-             const float d = qb_curr.d;
-             const float m = qb_curr.m;
-             float acc = 0.f;
-             for (int i = 0; i < 16; i += 2) {
-                 acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-                 acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-             }
-             sumf[row] += d * acc + m * sumy;
-             qb_curr = qb_next;
-         }
-     }
-
-     if (nb % N_SIMDWIDTH == 0) {
-         for (int row = 0; row < N_DST; ++row) {
-             all_sum = simd_sum(sumf[row]);
-             if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                 dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-             }
-         }
-     } else {
-
-         float sumy = 0;
-         for (int i = 0; i < QK4_0 / 4; i++) {
-             y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
-             sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-         }
-         for (int i = 0; i < 32; i++) {
-             yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
-         }
-
-         for (int row = 0; row < N_DST; row++) {
-             // prefetch next x block
-             qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-             // calculate
-             const float d = qb_curr.d;
-             const float m = qb_curr.m;
-             float acc = 0.f;
-             for (int i = 0; i < 16; i += 2) {
-                 acc += yl[i]     * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
-                 acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
-             }
-             if (tiisg < nb % N_SIMDWIDTH) {
-                 sumf[row] += d * acc + m * sumy;
-             }
-             qb_curr = qb_next;
+     mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ }

-             all_sum = simd_sum(sumf[row]);
-             if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                 dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-             }
-         }
-     }
+ kernel void kernel_mul_mat_q4_1_f32(
+         device const void * src0,
+         device const float * src1,
+         device float * dst,
+         constant int64_t & ne00,
+         constant int64_t & ne10,
+         constant int64_t & ne0,
+         constant int64_t & ne01[[buffer(4)]],
+         uint2 tgpig[[threadgroup_position_in_grid]],
+         uint tiisg[[thread_index_in_simdgroup]],
+         uint sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_f16_f32 (
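With the body factored into mul_vec_q_n_f32<block_q_type>, both kernel_mul_mat_q4_0_f32 and kernel_mul_mat_q4_1_f32 reduce to a one-line template call, and the only per-format code left is the block_q_n_dot_y() overload set. The C++ sketch below is a hypothetical illustration of that dispatch shape only, with toy stand-in structs and math; it is not the real ggml-metal code.

// Hypothetical sketch of the refactoring pattern: one templated worker plus one
// block_q_n_dot_y() overload per block type; each entry point is a one-line call.
#include <cstdio>

struct block_q4_0 { float d;          };   // toy stand-ins: scale (and min) only
struct block_q4_1 { float d; float m; };

// per-format math lives in the overloads (reduced here to the scale/offset terms)
float block_q_n_dot_y(block_q4_0 b, float acc, float sumy) { return b.d * (acc + sumy * -8.f); }
float block_q_n_dot_y(block_q4_1 b, float acc, float sumy) { return b.d * acc + b.m * sumy; }

// the shared "kernel body": everything format-independent goes here
template <typename block_q_type>
float mul_vec_q_n_f32(block_q_type b, float acc, float sumy) {
    return block_q_n_dot_y(b, acc, sumy);
}

int main() {
    block_q4_0 b0 = { 0.25f };
    block_q4_1 b1 = { 0.25f, 0.5f };
    // "kernels" reduced to one-line template calls, as in the diff
    printf("%g\n", mul_vec_q_n_f32<block_q4_0>(b0, 10.f, 2.f));
    printf("%g\n", mul_vec_q_n_f32<block_q4_1>(b1, 10.f, 2.f));
    return 0;
}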