@@ -18,6 +18,12 @@ typedef struct {
18
18
uint8_t qs[QK4_1 / 2 ]; // nibbles / quants
19
19
} block_q4_1;
20
20
21
// Q8_0 quantization block.
// QK8_0 is the number of quantized values stored per block.
// Each block holds one half-precision scale `d` and QK8_0 signed 8-bit quants;
// dequantize_q8_0 reconstructs a value as qs[i] * d.
// NOTE(review): field order and sizes presumably have to match the host-side
// definition of the same block — confirm against the CPU implementation.
+ #define QK8_0 32
22
+ typedef struct {
23
+ half d; // delta
24
+ int8_t qs[QK8_0]; // quants
25
+ } block_q8_0;
26
+
21
27
kernel void kernel_add (
22
28
device const float * src0,
23
29
device const float * src1,
@@ -1621,12 +1627,12 @@ template <typename type4x4>
1621
1627
// Dequantize 16 values of a Q4_0 block into the 4x4 tile `reg`.
//
//   xb  - block to decode
//   il  - 0 selects the low nibble of every quant byte, non-zero selects the
//         high nibble (see mask0 / mask1 below)
//   reg - output tile; iteration i writes reg[i/2][2*(i%2)] and
//         reg[i/2][2*(i%2)+1], so the 8 iterations fill all 16 elements
//
// `qs` starts one uint16_t past the start of the block and is read as 16-bit
// words, i.e. two quant bytes per word.
// NOTE(review): the skipped 16 bits are presumably the half `d` field of
// block_q4_0 — the struct definition is not visible here, confirm.
//
// For the high-nibble case the first masked value is used unshifted and the
// second is shifted right by 8 only, so both are 16x the nibble value q.
// Dividing d by 16 and multiplying m by 16 compensates: algebraically both
// cases evaluate to (q - 8) * d.
//
// The paired `-` / `+` lines below differ only in whitespace.
void dequantize_q4_0 (device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1622
1628
device const uint16_t * qs = ((device const uint16_t *)xb + 1 );
1623
1629
const half d = il ? (xb->d / 16 .h ) : xb->d ;
1624
- const half m = il ? (-8 .h * 16 .h ) : -8 .h ;
1630
+ const half m = il ? ( -8 .h * 16 .h ) : -8 .h ;
1625
1631
const ushort mask0 = il ? 0x00F0 : 0x000F ;
1626
1632
const ushort mask1 = il ? 0xF000 : 0x0F00 ;
1627
1633
1628
1634
for (int i=0 ;i<8 ;i++) {
1629
- reg[i/2 ][2 *(i%2 )] = (((qs[i] & mask0)) + m) * d;
1635
+ reg[i/2 ][2 *(i%2 )] = (((qs[i] & mask0) ) + m) * d;
1630
1636
reg[i/2 ][2 *(i%2 )+1 ] = (((qs[i] & mask1) >> 8 ) + m) * d;
1631
1637
}
1632
1638
}
@@ -1640,11 +1646,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
1640
1646
const ushort mask1 = il ? 0xF000 : 0x0F00 ;
1641
1647
1642
1648
for (int i=0 ;i<8 ;i++) {
1643
- reg[i/2 ][2 *(i%2 )] = (((qs[i] & mask0)) * d) + m;
1649
+ reg[i/2 ][2 *(i%2 )] = (((qs[i] & mask0) ) * d) + m;
1644
1650
reg[i/2 ][2 *(i%2 )+1 ] = (((qs[i] & mask1) >> 8 ) * d) + m;
1645
1651
}
1646
1652
}
1647
1653
1654
// Dequantize 16 consecutive values of a Q8_0 block into the 4x4 tile `reg`.
//
//   xb  - block to decode; xb->qs holds signed 8-bit quants, xb->d the scale
//   il  - selects which run of 16 quants is decoded (offset 16*il into qs)
//   reg - output tile; element [i/4][i%4] receives qs[i + 16*il] * d
//
// The quants are declared int8_t in block_q8_0 and must be read as signed.
// Reading them through a uint8_t pointer would turn every negative quant q
// into q + 256 before the multiplication and corrupt the dequantized value.
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
    device const int8_t * qs = ((device const int8_t *)xb->qs);
    const half d = xb->d;

    for (int i = 0; i < 16; i++) {
        reg[i/4][i%4] = (qs[i + 16*il] * d);
    }
}
1663
+
1648
1664
template <typename type4x4>
1649
1665
void dequantize_q2_K (device const block_q2_K *xb, short il, thread type4x4 & reg) {
1650
1666
const half d = xb->d ;
@@ -1947,9 +1963,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
1947
1963
// get_rows_t is the function type shared by every kernel_get_rows
// instantiation below. Each `template [[host_name(...)]] kernel ...` line is an
// explicit instantiation of kernel_get_rows for one source type; host_name is
// the name under which that instantiation is exposed to the host.
// Template arguments are <block or vector type, count, dequantize function>.
// NOTE(review): the q8_0 entry passes 2 as the second argument, like q4_0 and
// q4_1. dequantize_q8_0 decodes 16 values per call at offset 16*il and QK8_0
// is 32, so two il values cover one block — presumably that is what the count
// means, but kernel_get_rows itself is not visible here; confirm.
// The paired `-` / `+` f16 lines differ only in whitespace.
typedef void (get_rows_t )(device const void *, device const int *, device float *, constant int64_t &, \
1948
1964
constant uint64_t &, constant uint64_t &, uint, uint, uint);
1949
1965
1950
- template [[host_name(" kernel_get_rows_f16" )]] kernel get_rows_t kernel_get_rows<half4x4, 1 , dequantize_f16>;
1966
+ template [[host_name(" kernel_get_rows_f16" )]] kernel get_rows_t kernel_get_rows<half4x4, 1 , dequantize_f16>;
1951
1967
template [[host_name(" kernel_get_rows_q4_0" )]] kernel get_rows_t kernel_get_rows<block_q4_0, 2 , dequantize_q4_0>;
1952
1968
template [[host_name(" kernel_get_rows_q4_1" )]] kernel get_rows_t kernel_get_rows<block_q4_1, 2 , dequantize_q4_1>;
1969
+ template [[host_name(" kernel_get_rows_q8_0" )]] kernel get_rows_t kernel_get_rows<block_q8_0, 2 , dequantize_q8_0>;
1953
1970
template [[host_name(" kernel_get_rows_q2_K" )]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
1954
1971
template [[host_name(" kernel_get_rows_q3_K" )]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
1955
1972
template [[host_name(" kernel_get_rows_q4_K" )]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1960,7 +1977,7 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
1960
1977
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
1961
1978
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
1962
1979
1963
// Explicit instantiations of kernel_mul_mm, one per source type, each exposed
// to the host under the name given in host_name. Template arguments are
// <block or vector type, count, dequantize function>.
// The paired `-` / `+` f16 lines differ only in whitespace.
// NOTE(review): no block_q8_0 / dequantize_q8_0 instantiation appears between
// the q4_1 and q2_K entries here, although the kernel_get_rows list has one —
// confirm whether matrix multiplication for q8_0 is intentionally left out.
- template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mat_mm_t kernel_mul_mm<half4x4, 1 , dequantize_f16>;
1980
+ template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mat_mm_t kernel_mul_mm<half4x4, 1 , dequantize_f16>;
1964
1981
template [[host_name(" kernel_mul_mm_q4_0_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2 , dequantize_q4_0>;
1965
1982
template [[host_name(" kernel_mul_mm_q4_1_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2 , dequantize_q4_1>;
1966
1983
template [[host_name(" kernel_mul_mm_q2_K_f32" )]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
0 commit comments