@@ -1642,39 +1642,39 @@ kernel void kernel_mul_mat_q5_K_f32(
1642
1642
constant int64_t & ne00,
1643
1643
constant int64_t & ne10,
1644
1644
constant int64_t & ne0,
1645
- threadgroup float * sum [[threadgroup(0 )]],
1646
1645
uint2 tgpig[[threadgroup_position_in_grid]],
1647
- uint2 tpitg[[thread_position_in_threadgroup ]],
1648
- uint2 tptg[[threads_per_threadgroup ]]) {
1646
+ uint tiisg[[thread_index_in_simdgroup ]],
1647
+ uint sgitg[[simdgroup_index_in_threadgroup ]]) {
1649
1648
1650
1649
const int nb = ne00/QK_K;
1651
1650
1652
1651
const int64_t r0 = tgpig.x ;
1653
1652
const int64_t r1 = tgpig.y ;
1654
1653
1655
- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1654
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2 ;
1655
+
1656
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
1656
1657
device const float * yy = (device const float *) src1 + r1*ne10;
1657
1658
1658
- const int nth = tptg.x *tptg.y ;
1659
- const int ith = tptg.y *tpitg.x + tpitg.y ;
1659
+ float sumf[2 ]={0 .f };
1660
1660
1661
- float sumf = 0 ;
1661
+ const int step = sizeof (block_q5_K) * nb ;
1662
1662
1663
1663
#if QK_K == 256
1664
+ #
1665
+ float yl[16 ], yh[16 ];
1664
1666
1665
1667
const uint16_t kmask1 = 0x3f3f ;
1666
1668
const uint16_t kmask2 = 0x0f0f ;
1667
1669
const uint16_t kmask3 = 0xc0c0 ;
1668
1670
1669
- const int tid = tpitg.y ; // 0...16
1670
- const int il = tid/4 ; // 0...3
1671
- const int ir = tid - 4 *il;// 0...3
1672
- const int n = 4 ;
1673
-
1674
- const int im = il/2 ; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1675
- const int in = il%2 ;
1671
+ const int tid = tiisg/4 ;
1672
+ const int ix = tiisg%4 ;
1673
+ const int im = tid/4 ;
1674
+ const int ir = tid%4 ;
1675
+ const int n = 8 ;
1676
1676
1677
- const int l0 = n*( 2 *ir + in) ;
1677
+ const int l0 = n*ir ;
1678
1678
const int q_offset = 32 *im + l0;
1679
1679
const int y_offset = 64 *im + l0;
1680
1680
@@ -1683,78 +1683,114 @@ kernel void kernel_mul_mat_q5_K_f32(
1683
1683
const uint8_t hm3 = hm1 << 4 ;
1684
1684
const uint8_t hm4 = hm2 << 4 ;
1685
1685
1686
- uchar2 sc1, sc2, sc3, sc4;
1686
+ uint16_t sc16[4 ];
1687
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1687
1688
1688
- for ( int i = tpitg. x ; i < nb; i += tptg. x ) {
1689
+ device const float * y1 = yy + ix*QK_K + y_offset;
1689
1690
1690
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1691
- device const uint8_t * q2 = q1 + 64 ;
1692
- device const uint8_t * qh = (x + i)->qh + l0;
1693
- device const float * y1 = yy + i*QK_K + y_offset;
1694
- device const float * y2 = y1 + 128 ;
1691
+ for (int i = ix; i < nb; i += 4 ) {
1695
1692
1696
- const float dall = (float )((x + i)->d );
1697
- const float dmin = (float )((x + i)->dmin );
1693
+ device const uint8_t * q1 = x[i].qs + q_offset;
1694
+ device const uint8_t * qh = x[i].qh + l0;
1695
+ device const half * dh = &x[i].d ;
1696
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
1698
1697
1699
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales ;
1700
- sc1 = as_type<uchar2>((uint16_t )(a[im+0 ] & kmask1));
1701
- sc2 = as_type<uchar2>((uint16_t )(a[im+2 ] & kmask1));
1702
- sc3 = as_type<uchar2>((uint16_t )(((a[im+4 ] >> 0 ) & kmask2) | ((a[im+0 ] & kmask3) >> 2 )));
1703
- sc4 = as_type<uchar2>((uint16_t )(((a[im+4 ] >> 4 ) & kmask2) | ((a[im+2 ] & kmask3) >> 2 )));
1698
+ device const float * y2 = y1 + 128 ;
1699
+ float4 sumy = {0 .f , 0 .f , 0 .f , 0 .f };
1700
+ for (int l = 0 ; l < 8 ; ++l) {
1701
+ yl[l+0 ] = y1[l+ 0 ]; sumy[0 ] += yl[l+0 ];
1702
+ yl[l+8 ] = y1[l+32 ]; sumy[1 ] += yl[l+8 ];
1703
+ yh[l+0 ] = y2[l+ 0 ]; sumy[2 ] += yh[l+0 ];
1704
+ yh[l+8 ] = y2[l+32 ]; sumy[3 ] += yh[l+8 ];
1705
+ }
1704
1706
1705
- float4 s = {0 .f , 0 .f , 0 .f , 0 .f };
1706
- float smin = 0 ;
1707
- for (int l = 0 ; l < n; ++l) {
1707
+ for (int row = 0 ; row < 2 ; ++row) {
1708
+
1709
+ device const uint8_t * q2 = q1 + 64 ;
1710
+
1711
+ sc16[0 ] = a[0 ] & kmask1;
1712
+ sc16[1 ] = a[2 ] & kmask1;
1713
+ sc16[2 ] = ((a[4 ] >> 0 ) & kmask2) | ((a[0 ] & kmask3) >> 2 );
1714
+ sc16[3 ] = ((a[4 ] >> 4 ) & kmask2) | ((a[2 ] & kmask3) >> 2 );
1715
+
1716
+ float4 acc = {0 .f , 0 .f , 0 .f , 0 .f };
1717
+ for (int l = 0 ; l < n; ++l) {
1718
+ uint8_t h = qh[l];
1719
+ acc[0 ] += yl[l+0 ] * ((uint16_t )(q1[l] & 0x0F ) + (h & hm1 ? 16 : 0 ));
1720
+ acc[1 ] += yl[l+8 ] * ((uint16_t )(q1[l] & 0xF0 ) + (h & hm2 ? 256 : 0 ));
1721
+ acc[2 ] += yh[l+0 ] * ((uint16_t )(q2[l] & 0x0F ) + (h & hm3 ? 16 : 0 ));
1722
+ acc[3 ] += yh[l+8 ] * ((uint16_t )(q2[l] & 0xF0 ) + (h & hm4 ? 256 : 0 ));
1723
+ }
1724
+ const float dall = dh[0 ];
1725
+ const float dmin = dh[1 ];
1726
+ sumf[row] += dall * (acc[0 ] * sc8[0 ] + acc[1 ] * sc8[1 ] * 1 .f /16 .f + acc[2 ] * sc8[4 ] + acc[3 ] * sc8[5 ] * 1 .f /16 .f ) -
1727
+ dmin * (sumy[0 ] * sc8[2 ] + sumy[1 ] * sc8[3 ] + sumy[2 ] * sc8[6 ] + sumy[3 ] * sc8[7 ]);
1708
1728
1709
- s[0 ] += y1[l+ 0 ] * ((q1[l] & 0xF ) + (qh[l] & hm1 ? 16 : 0 ));
1710
- s[1 ] += y1[l+32 ] * ((q1[l] >> 4 ) + (qh[l] & hm2 ? 16 : 0 ));
1711
- s[2 ] += y2[l+ 0 ] * ((q2[l] & 0xF ) + (qh[l] & hm3 ? 16 : 0 ));
1712
- s[3 ] += y2[l+32 ] * ((q2[l] >> 4 ) + (qh[l] & hm4 ? 16 : 0 ));
1713
- smin += y1[l] * sc2[0 ] + y1[l+32 ] * sc2[1 ] + y2[l] * sc4[0 ] + y2[l+32 ] * sc4[1 ];
1729
+ q1 += step;
1730
+ qh += step;
1731
+ dh += step/2 ;
1732
+ a += step/2 ;
1714
1733
1715
1734
}
1716
- sumf += dall * (s[0 ] * sc1[0 ] + s[1 ] * sc1[1 ] + s[2 ] * sc3[0 ] + s[3 ] * sc3[1 ]) - dmin * smin;
1735
+
1736
+ y1 += 4 * QK_K;
1717
1737
1718
1738
}
1719
1739
#else
1720
- const int il = 4 * tpitg.x ; // 0, 4, 8, 12
1721
- const int im = il/8 ; // 0, 0, 1, 1
1722
- const int in = il%8 ; // 0, 4, 0, 4
1740
+ float yl[8 ], yh[8 ];
1723
1741
1724
- for (int i = tpitg.y ; i < nb; i += tptg.y ) {
1742
+ const int il = 4 * (tiisg/8 ); // 0, 4, 8, 12
1743
+ const int ix = tiisg%8 ;
1744
+ const int im = il/8 ; // 0, 0, 1, 1
1745
+ const int in = il%8 ; // 0, 4, 0, 4
1725
1746
1726
- const float d = (float )x[i].d ;
1747
+ device const float * y = yy + ix*QK_K + il;
1748
+
1749
+ for (int i = ix; i < nb; i += 8 ) {
1750
+
1751
+ float4 sumy = {0 .f , 0 .f , 0 .f , 0 .f };
1752
+ for (int l = 0 ; l < 4 ; ++l) {
1753
+ yl[l+0 ] = y[l+ 0 ];
1754
+ yl[l+4 ] = y[l+16 ];
1755
+ yh[l+0 ] = y[l+32 ];
1756
+ yh[l+4 ] = y[l+48 ];
1757
+ }
1758
+
1759
+ device const half * dh = &x[i].d ;
1727
1760
device const uint8_t * q = x[i].qs + il;
1728
1761
device const uint8_t * h = x[i].qh + in;
1729
1762
device const int8_t * s = x[i].scales ;
1730
- device const float * y = yy + i*QK_K + il;
1731
1763
1732
- for (int l = 0 ; l < 4 ; ++l) {
1733
- const uint8_t hl = h[l] >> im;
1734
- sumf += y[l+ 0 ] * d * s[0 ] * ((q[l+ 0 ] & 0xF ) - (hl & 0x01 ? 0 : 16 ))
1735
- + y[l+16 ] * d * s[1 ] * ((q[l+16 ] & 0xF ) - (hl & 0x04 ? 0 : 16 ))
1736
- + y[l+32 ] * d * s[2 ] * ((q[l+ 0 ] >> 4 ) - (hl & 0x10 ? 0 : 16 ))
1737
- + y[l+48 ] * d * s[3 ] * ((q[l+16 ] >> 4 ) - (hl & 0x40 ? 0 : 16 ));
1764
+ for (int row = 0 ; row < 2 ; ++row) {
1765
+
1766
+ const float d = dh[0 ];
1767
+
1768
+ float2 acc = {0 .f , 0 .f };
1769
+ for (int l = 0 ; l < 4 ; ++l) {
1770
+ const uint8_t hl = h[l] >> im;
1771
+ acc[0 ] += yl[l+0 ] * s[0 ] * ((int16_t )(q[l+ 0 ] & 0x0F ) - (hl & 0x01 ? 0 : 16 ))
1772
+ + yl[l+4 ] * s[1 ] * ((int16_t )(q[l+16 ] & 0x0F ) - (hl & 0x04 ? 0 : 16 ));
1773
+ acc[1 ] += yh[l+0 ] * s[2 ] * ((int16_t )(q[l+ 0 ] & 0xF0 ) - (hl & 0x10 ? 0 : 256 ))
1774
+ + yh[l+4 ] * s[3 ] * ((int16_t )(q[l+16 ] & 0xF0 ) - (hl & 0x40 ? 0 : 256 ));
1775
+ }
1776
+ sumf[row] += d * (acc[0 ] + 1 .f /16 .f * acc[1 ]);
1777
+
1778
+ q += step;
1779
+ h += step;
1780
+ s += step;
1781
+ dh += step/2 ;
1782
+
1738
1783
}
1784
+
1785
+ y += 8 * QK_K;
1739
1786
}
1740
1787
#endif
1741
- sum[ith] = sumf;
1742
1788
1743
- //
1744
- // Accumulate the sum from all threads in the threadgroup
1745
- //
1746
- threadgroup_barrier (mem_flags::mem_threadgroup);
1747
- if (ith%4 == 0 ) {
1748
- sum[ith] += sum[ith+1 ] + sum[ith+2 ] + sum[ith+3 ];
1749
- }
1750
- threadgroup_barrier (mem_flags::mem_threadgroup);
1751
- if (ith%16 == 0 ) {
1752
- sum[ith] += sum[ith+4 ] + sum[ith+8 ] + sum[ith+12 ];
1753
- }
1754
- threadgroup_barrier (mem_flags::mem_threadgroup);
1755
- if (ith == 0 ) {
1756
- for (int i = 16 ; i < nth; i += 16 ) sum[0 ] += sum[i];
1757
- dst[r1*ne0 + r0] = sum[0 ];
1789
+ for (int row = 0 ; row < 2 ; ++row) {
1790
+ const float tot = simd_sum (sumf[row]);
1791
+ if (tiisg == 0 ) {
1792
+ dst[r1*ne0 + first_row + row] = tot;
1793
+ }
1758
1794
}
1759
1795
1760
1796
}
@@ -1766,10 +1802,9 @@ kernel void kernel_mul_mat_q6_K_f32(
1766
1802
constant int64_t & ne00,
1767
1803
constant int64_t & ne10,
1768
1804
constant int64_t & ne0,
1769
- threadgroup float * sum [[threadgroup(0 )]],
1770
1805
uint2 tgpig[[threadgroup_position_in_grid]],
1771
- uint2 tpitg[[thread_position_in_threadgroup ]],
1772
- uint2 tptg[[threads_per_threadgroup ]]) {
1806
+ uint tiisg[[thread_index_in_simdgroup ]],
1807
+ uint sgitg[[simdgroup_index_in_threadgroup ]]) {
1773
1808
1774
1809
const uint8_t kmask1 = 0x03 ;
1775
1810
const uint8_t kmask2 = 0x0C ;
@@ -1781,19 +1816,18 @@ kernel void kernel_mul_mat_q6_K_f32(
1781
1816
const int64_t r0 = tgpig.x ;
1782
1817
const int64_t r1 = tgpig.y ;
1783
1818
1784
- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1785
- device const float * yy = (device const float *) src1 + r1*ne10;
1819
+ const int row = 2 * r0 + sgitg;
1786
1820
1787
- const int nth = tptg. x *tptg. y ;
1788
- const int ith = tptg. y *tpitg. x + tpitg. y ;
1821
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; // r0*nb ;
1822
+ device const float * yy = (device const float *) src1 + r1*ne10 ;
1789
1823
1790
1824
float sumf = 0 ;
1791
1825
1792
1826
#if QK_K == 256
1793
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1794
- const int iqs = 16 * tpitg. y ;
1795
- const int ip = iqs / 128 ; // 0 or 1
1796
- const int il = (iqs - 128 *ip)/ 16 ; // 0...7
1827
+ const int tid = tiisg/ 2 ;
1828
+ const int ix = tiisg% 2 ;
1829
+ const int ip = tid/ 8 ; // 0 or 1
1830
+ const int il = tid% 8 ;
1797
1831
const int n = 4 ;
1798
1832
const int l0 = n*il;
1799
1833
const int is = 8 *ip + l0/16 ;
@@ -1802,9 +1836,10 @@ kernel void kernel_mul_mat_q6_K_f32(
1802
1836
const int q_offset_l = 64 *ip + l0;
1803
1837
const int q_offset_h = 32 *ip + l0;
1804
1838
1805
- for (int i = tpitg. x ; i < nb; i += tptg. x ) {
1839
+ for (int i = ix ; i < nb; i += 2 ) {
1806
1840
1807
- device const uint8_t * ql = x[i].ql + q_offset_l;
1841
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
1842
+ device const uint8_t * q2 = q1 + 32 ;
1808
1843
device const uint8_t * qh = x[i].qh + q_offset_h;
1809
1844
device const int8_t * sc = x[i].scales + is;
1810
1845
@@ -1814,19 +1849,21 @@ kernel void kernel_mul_mat_q6_K_f32(
1814
1849
1815
1850
float4 sums = {0 .f , 0 .f , 0 .f , 0 .f };
1816
1851
for (int l = 0 ; l < n; ++l) {
1817
- sums[0 ] += y[l+ 0 ] * ((int8_t )((ql[l+ 0 ] & 0xF ) | ((qh[l] & kmask1) << 4 )) - 32 );
1818
- sums[1 ] += y[l+32 ] * ((int8_t )((ql[l+ 32 ] & 0xF ) | ((qh[l] & kmask2) << 2 )) - 32 );
1819
- sums[2 ] += y[l+64 ] * ((int8_t )((ql[l+ 0 ] >> 4 ) | ((qh[l] & kmask3) << 0 )) - 32 );
1820
- sums[3 ] += y[l+96 ] * ((int8_t )((ql[l+ 32 ] >> 4 ) | ((qh[l] & kmask4) >> 2 )) - 32 );
1852
+ sums[0 ] += y[l+ 0 ] * ((int8_t )((q1[l ] & 0xF ) | ((qh[l] & kmask1) << 4 )) - 32 );
1853
+ sums[1 ] += y[l+32 ] * ((int8_t )((q2[l ] & 0xF ) | ((qh[l] & kmask2) << 2 )) - 32 );
1854
+ sums[2 ] += y[l+64 ] * ((int8_t )((q1[l ] >> 4 ) | ((qh[l] & kmask3) << 0 )) - 32 );
1855
+ sums[3 ] += y[l+96 ] * ((int8_t )((q2[l ] >> 4 ) | ((qh[l] & kmask4) >> 2 )) - 32 );
1821
1856
}
1822
1857
1823
1858
sumf += dall * (sums[0 ] * sc[0 ] + sums[1 ] * sc[2 ] + sums[2 ] * sc[4 ] + sums[3 ] * sc[6 ]);
1824
1859
1825
1860
}
1861
+
1826
1862
#else
1827
- const int il = 4 *tpitg.x ; // 0, 4, 8, 12
1863
+ const int ix = tiisg/4 ;
1864
+ const int il = 4 *(tiisg%4 );
1828
1865
1829
- for (int i = tpitg. y ; i < nb; i += tptg. y ) {
1866
+ for (int i = ix ; i < nb; i += 8 ) {
1830
1867
device const float * y = yy + i * QK_K + il;
1831
1868
device const uint8_t * ql = x[i].ql + il;
1832
1869
device const uint8_t * qh = x[i].qh + il;
@@ -1846,23 +1883,8 @@ kernel void kernel_mul_mat_q6_K_f32(
1846
1883
1847
1884
#endif
1848
1885
1849
- sum[ith] = sumf;
1850
-
1851
- //
1852
- // Accumulate the sum from all threads in the threadgroup
1853
- //
1854
- threadgroup_barrier (mem_flags::mem_threadgroup);
1855
- if (ith%4 == 0 ) {
1856
- for (int i = 1 ; i < 4 ; ++i) sum[ith] += sum[ith + i];
1857
- }
1858
- threadgroup_barrier (mem_flags::mem_threadgroup);
1859
- if (ith%16 == 0 ) {
1860
- for (int i = 4 ; i < 16 ; i += 4 ) sum[ith] += sum[ith + i];
1861
- }
1862
- threadgroup_barrier (mem_flags::mem_threadgroup);
1863
- if (ith == 0 ) {
1864
- for (int i = 16 ; i < nth; i += 16 ) sum[0 ] += sum[i];
1865
- dst[r1*ne0 + r0] = sum[0 ];
1886
+ const float tot = simd_sum (sumf);
1887
+ if (tiisg == 0 ) {
1888
+ dst[r1*ne0 + row] = tot;
1866
1889
}
1867
-
1868
1890
}
0 commit comments