@@ -363,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     const int first_row = (r0 * nsg + sgitg) * nr;
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16];       // src1 vector cache
     float sumf[nr]={0.f};
 
@@ -435,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mat_q8_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += QK8_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
         device const char * src0,
         device const char * src1,
@@ -486,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
-
 kernel void kernel_alibi_f32(
         device const float * src0,
         device float * dst,
@@ -1653,7 +1714,7 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
 
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * qs = ((device const uint8_t *)xb->qs);
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
     const half d = xb->d;
 
     for (int i=0;i<16;i++) {
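
Note: the following is a host-side reference sketch, not part of the diff above. It spells out the per-row math that kernel_mul_mat_q8_0_f32 performs, assuming ggml's Q8_0 layout of QK8_0 == 32 values per block, each block carrying a scale d and 32 signed 8-bit quants qs. It is written in plain C++ so it can be checked on the CPU; the d field is kept as float here (it is a half on the Metal side), and dot_q8_0_row plus the small main() are illustrative names only. The per-thread work split (ix/il, 16 values per thread) and the final simd_sum() reduction across the SIMD group are intentionally omitted.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #define QK8_0 32

    // Q8_0 block: one scale and 32 signed 8-bit quants (the scale is a half on the GPU).
    struct block_q8_0 {
        float  d;
        int8_t qs[QK8_0];
    };

    // Dot product of one quantized row (nb blocks) with a dense float vector y.
    // Each block contributes d * sum_i(qs[i] * y[i]), i.e. the sumq * x[ib+row*nb].d
    // accumulation in the Metal kernel, without the SIMD-group parallelism.
    static float dot_q8_0_row(const std::vector<block_q8_0> & x, const float * y, int nb) {
        float sumf = 0.0f;
        for (int ib = 0; ib < nb; ++ib) {
            float sumq = 0.0f;
            for (int i = 0; i < QK8_0; ++i) {
                sumq += x[ib].qs[i] * y[ib*QK8_0 + i];
            }
            sumf += sumq * x[ib].d;
        }
        return sumf;
    }

    int main() {
        // Tiny smoke test: one block with qs[i] = i - 16, d = 0.5, y[i] = 1.
        // Expected result: 0.5 * sum(i - 16) = 0.5 * (-16) = -8.
        std::vector<block_q8_0> x(1);
        x[0].d = 0.5f;
        float y[QK8_0];
        for (int i = 0; i < QK8_0; ++i) {
            x[0].qs[i] = (int8_t)(i - 16);
            y[i] = 1.0f;
        }
        std::printf("dot = %f\n", dot_q8_0_row(x, y, 1));
        return 0;
    }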