@@ -18,6 +18,12 @@ typedef struct {
18
18
uint8_t qs[QK4_1 / 2 ]; // nibbles / quants
19
19
} block_q4_1;
20
20
21
+ #define QK8_0 32
22
+ typedef struct {
23
+ half d; // delta
24
+ int8_t qs[QK8_0]; // quants
25
+ } block_q8_0;
26
+
21
27
kernel void kernel_add (
22
28
device const float * src0,
23
29
device const float * src1,
@@ -357,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
357
363
const int first_row = (r0 * nsg + sgitg) * nr;
358
364
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
359
365
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
360
- device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
366
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
361
367
float yl[16 ]; // src1 vector cache
362
368
float sumf[nr]={0 .f };
363
369
@@ -429,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
429
435
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
430
436
}
431
437
438
+ kernel void kernel_mul_mat_q8_0_f32 (
439
+ device const void * src0,
440
+ device const float * src1,
441
+ device float * dst,
442
+ constant int64_t & ne00,
443
+ constant int64_t & ne01[[buffer(4 )]],
444
+ constant int64_t & ne02[[buffer(5 )]],
445
+ constant int64_t & ne10[[buffer(9 )]],
446
+ constant int64_t & ne12[[buffer(11 )]],
447
+ constant int64_t & ne0[[buffer(15 )]],
448
+ constant int64_t & ne1[[buffer(16 )]],
449
+ constant uint & gqa[[buffer(17 )]],
450
+ uint3 tgpig[[threadgroup_position_in_grid]],
451
+ uint tiisg[[thread_index_in_simdgroup]],
452
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
453
+ const int nr = N_DST;
454
+ const int nsg = N_SIMDGROUP;
455
+ const int nw = N_SIMDWIDTH;
456
+
457
+ const int nb = ne00/QK8_0;
458
+ const int r0 = tgpig.x ;
459
+ const int r1 = tgpig.y ;
460
+ const int im = tgpig.z ;
461
+ const int first_row = (r0 * nsg + sgitg) * nr;
462
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
463
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
464
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
465
+
466
+ float yl[16 ];
467
+ float sumf[nr]={0 .f };
468
+
469
+ const int ix = tiisg/2 ;
470
+ const int il = tiisg%2 ;
471
+
472
+ device const float * yb = y + ix * QK8_0 + 16 *il;
473
+
474
+ // each thread in a SIMD group deals with half a block.
475
+ for (int ib = ix; ib < nb; ib += nw/2 ) {
476
+ for (int i = 0 ; i < 16 ; ++i) {
477
+ yl[i] = yb[i];
478
+ }
479
+
480
+ for (int row = 0 ; row < nr; row++) {
481
+ device const int8_t * qs = x[ib+row*nb].qs + 16 *il;
482
+ float sumq = 0 .f ;
483
+ for (int iq = 0 ; iq < 16 ; ++iq) {
484
+ sumq += qs[iq] * yl[iq];
485
+ }
486
+ sumf[row] += sumq*x[ib+row*nb].d ;
487
+ }
488
+
489
+ yb += QK8_0 * 16 ;
490
+ }
491
+
492
+ for (int row = 0 ; row < nr; ++row) {
493
+ const float tot = simd_sum (sumf[row]);
494
+ if (tiisg == 0 && first_row + row < ne01) {
495
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
496
+ }
497
+ }
498
+ }
499
+
432
500
kernel void kernel_mul_mat_f16_f32 (
433
501
device const char * src0,
434
502
device const char * src1,
@@ -480,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
480
548
}
481
549
}
482
550
483
-
484
551
kernel void kernel_alibi_f32 (
485
552
device const float * src0,
486
553
device float * dst,
@@ -1621,12 +1688,12 @@ template <typename type4x4>
1621
1688
void dequantize_q4_0 (device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1622
1689
device const uint16_t * qs = ((device const uint16_t *)xb + 1 );
1623
1690
const half d = il ? (xb->d / 16 .h ) : xb->d ;
1624
- const half m = il ? (-8 .h * 16 .h ) : -8 .h ;
1691
+ const half m = il ? ( -8 .h * 16 .h ) : -8 .h ;
1625
1692
const ushort mask0 = il ? 0x00F0 : 0x000F ;
1626
1693
const ushort mask1 = il ? 0xF000 : 0x0F00 ;
1627
1694
1628
1695
for (int i=0 ;i<8 ;i++) {
1629
- reg[i/2 ][2 *(i%2 )] = (((qs[i] & mask0)) + m) * d;
1696
+ reg[i/2 ][2 *(i%2 )] = (((qs[i] & mask0) ) + m) * d;
1630
1697
reg[i/2 ][2 *(i%2 )+1 ] = (((qs[i] & mask1) >> 8 ) + m) * d;
1631
1698
}
1632
1699
}
@@ -1640,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
1640
1707
const ushort mask1 = il ? 0xF000 : 0x0F00 ;
1641
1708
1642
1709
for (int i=0 ;i<8 ;i++) {
1643
- reg[i/2 ][2 *(i%2 )] = (((qs[i] & mask0)) * d) + m;
1710
+ reg[i/2 ][2 *(i%2 )] = (((qs[i] & mask0) ) * d) + m;
1644
1711
reg[i/2 ][2 *(i%2 )+1 ] = (((qs[i] & mask1) >> 8 ) * d) + m;
1645
1712
}
1646
1713
}
1647
1714
1715
+ template <typename type4x4>
1716
+ void dequantize_q8_0 (device const block_q8_0 *xb, short il, thread type4x4 & reg) {
1717
+ device const int8_t * qs = ((device const int8_t *)xb->qs );
1718
+ const half d = xb->d ;
1719
+
1720
+ for (int i=0 ;i<16 ;i++) {
1721
+ reg[i/4 ][i%4 ] = (qs[i + 16 *il] * d);
1722
+ }
1723
+ }
1724
+
1648
1725
template <typename type4x4>
1649
1726
void dequantize_q2_K (device const block_q2_K *xb, short il, thread type4x4 & reg) {
1650
1727
const half d = xb->d ;
@@ -1947,9 +2024,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
1947
2024
typedef void (get_rows_t )(device const void *, device const int *, device float *, constant int64_t &, \
1948
2025
constant uint64_t &, constant uint64_t &, uint, uint, uint);
1949
2026
1950
- template [[host_name(" kernel_get_rows_f16" )]] kernel get_rows_t kernel_get_rows<half4x4, 1 , dequantize_f16>;
2027
+ template [[host_name(" kernel_get_rows_f16" )]] kernel get_rows_t kernel_get_rows<half4x4, 1 , dequantize_f16>;
1951
2028
template [[host_name(" kernel_get_rows_q4_0" )]] kernel get_rows_t kernel_get_rows<block_q4_0, 2 , dequantize_q4_0>;
1952
2029
template [[host_name(" kernel_get_rows_q4_1" )]] kernel get_rows_t kernel_get_rows<block_q4_1, 2 , dequantize_q4_1>;
2030
+ template [[host_name(" kernel_get_rows_q8_0" )]] kernel get_rows_t kernel_get_rows<block_q8_0, 2 , dequantize_q8_0>;
1953
2031
template [[host_name(" kernel_get_rows_q2_K" )]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
1954
2032
template [[host_name(" kernel_get_rows_q3_K" )]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
1955
2033
template [[host_name(" kernel_get_rows_q4_K" )]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1960,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
1960
2038
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
1961
2039
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
1962
2040
1963
- template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mat_mm_t kernel_mul_mm<half4x4, 1 , dequantize_f16>;
2041
+ template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mat_mm_t kernel_mul_mm<half4x4, 1 , dequantize_f16>;
1964
2042
template [[host_name(" kernel_mul_mm_q4_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2 , dequantize_q4_0>;
1965
2043
template [[host_name(" kernel_mul_mm_q4_1_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2 , dequantize_q4_1>;
2044
+ template [[host_name(" kernel_mul_mm_q8_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2 , dequantize_q8_0>;
1966
2045
template [[host_name(" kernel_mul_mm_q2_K_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
1967
2046
template [[host_name(" kernel_mul_mm_q3_K_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
1968
2047
template [[host_name(" kernel_mul_mm_q4_K_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
0 commit comments