@@ -1814,33 +1814,187 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
 }

 #if __AVX512F__ && QK == 32
-static inline __m512 dot_q4_0_oneblock_avx512(
+static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
+    // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |                                                                        :. =_ () [] <> () Zz Yy|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    // |Xx Ww Vv Uu Tt Ss Rr Qq             Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa            |
+    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+    //
+    // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
+    // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
+    // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
+    // Bytes 40..63 are masked when loading the data, so they are zeroed out.
+#ifdef __AVX512VBMI__
+    const __m512i byte_perm = _mm512_set_epi8(
+        39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
+        31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
+        19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
+        11, 10, 11, 10,  9,  8,  9,  8,  7,  6,  7,  6,  5,  4,  5,  4
+    );
+    const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
+    // After applying VPERMB, `permuted` looks like this:
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+#else
+    const __m512i word_perm = _mm512_set_epi16(
+        19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
+         9,  9,  8,  8,  7,  7,  6,  6,  5,  5,  4,  4,  3,  3,  2,  2
+    );
+    const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
+    // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
+    // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
+    // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
+#endif
+
+    // Shift every odd-numbered 16-bit group to the right by 4 bits.
+    const __mmask32 shift_mask = 0xaaaaaaaa;
+    const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
+    // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | : .= :. =_  ( )[ () []  < >( <> ()  Z zY Zz Yy  X xW Xx Ww  V vU Vv Uu  T tS Tt Ss  R rQ Rr Qq|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | P pO Pp Oo  N nM Nn Mm  L lK Ll Kk  J jI Jj Ii  H hG Hh Gg  F fE Ff Ee  D dC Dd Cc  B bA Bb Aa|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+
+    // Now we just need to zero out the higher nibble in each byte, and we're done.
+    const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
+    return _mm512_and_si512( low_nibble_mask, shifted );
+    // The final result looks like this:
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | :  =  .  _  (  [  )  ]  <  (  >  )  Z  Y  z  y  X  W  x  w  V  U  v  u  T  S  t  s  R  Q  r  q|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+    // | P  O  p  o  N  M  n  m  L  K  l  k  J  I  j  i  H  G  h  g  F  E  f  e  D  C  d  c  B  A  b  a|
+    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
+}
+
+static inline __m512 dot_q4_0_twoblocks_avx512(
     __m512 acc,
     const block_q4_0 * restrict x,
     const block_q4_0 * restrict y,
     int i
 ) {
-    // Compute combined scale for the block
-    __m512 d = _mm512_set1_ps( x[i].d * y[i].d );
-
-    __m256i bx = bytesFromNibbles( x[i].qs );
-    __m256i by = bytesFromNibbles( y[i].qs );
-
-    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-    const __m256i off = _mm256_set1_epi8( 8 );
-    bx = _mm256_sub_epi8( bx, off );
-    by = _mm256_sub_epi8( by, off );
-
-    // Sign-extend 16 signed bytes into int16_t
-    __m512i x32 = _mm512_cvtepi8_epi16( bx );
-    __m512i y32 = _mm512_cvtepi8_epi16( by );
-    // Compute products of int16_t integers, add pairwise
-    __m512i i64 = _mm512_madd_epi16( x32, y32 );
+    // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
+    // can potentially be unaddressable, so we make sure to mask them out before the load, even though
+    // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
+    // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
+    const __mmask8 load_mask = 0x1f;
+    const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
+    const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
+
+    // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // blocks_0_float
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // |    |    |    |    |    |    | xx | xx | xx | xx |  B | xx | xx | xx | xx |  A |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // blocks_1_float
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // |    |    |    |    |    |    | xx | xx | xx | xx |  D | xx | xx | xx | xx |  C |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
+    const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
+    // We absolutely shouldn't touch the floats marked with `xx`: they contain some
+    // random data, which might very well underflow. At least on Intel, this leads
+    // to a huge penalty that can't be ignored (easily 100x or more) unless you
+    // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
+    // (and ggml can't assume that you do)...
+    const __mmask16 scale_mul_mask = 0x21;
+#ifdef __clang__
+    // ...however, clang decides to optimize the multiplication mask away:
+    // https://godbolt.org/z/P8PqdsfvW
+    // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
+    __m512i scales;
+    __asm__(
+        "vmulps %1, %2, %0%{%3%}"
+        : "=v" ( scales )
+        : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
+    );
+#else
+    const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
+#endif
+    const __m512i scale_perm = _mm512_set_epi32(
+        5, 5, 5, 5, 5, 5, 5, 5,
+        0, 0, 0, 0, 0, 0, 0, 0
+    );
+    const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
+    // After VMULPS and VPERMPS, `permuted_scales` looks like this:
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+    // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
+    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
+
+    const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
+    const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
+
+    // Now we want to compute dot products of 4-element byte vectors and store them in
+    // 32-bit integers. That is (only one 4-element vector is shown for clarity):
+    //     +----+----+----+----+
+    // ... | 03 | 02 | 01 | 00 |
+    //     +----+----+----+----+
+    // bytes_0
+    //     +----+----+----+----+
+    // ... |  D |  C |  B |  A |
+    //     +----+----+----+----+
+    // bytes_1
+    //     +----+----+----+----+
+    // ... |  H |  G |  F |  E |
+    //     +----+----+----+----+
+    // final_res_int
+    //     +----+----+----+----+
+    // ... | A*E+B*F+C*G+D*H   |
+    //     +----+----+----+----+
+    const __m512i plus_8 = _mm512_set1_epi8( 8 );
+    const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
+
+#ifdef __AVX512VNNI__
+    // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
+    // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
+    // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
+    // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
+    // which means we only need 2 instructions.
+    const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
+    const __m512i minus_8 = _mm512_set1_epi8( -8 );
+    const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
+    const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
+#else
+    // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
+    // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
+    // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
+    // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
+    const __m512i one = _mm512_set1_epi16( 1 );
+    const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
+    const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
+    const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
+    const __m512i final_res_int = _mm512_madd_epi16( diff, one );
+#endif

-    // Convert int32_t to float
-    __m512 p = _mm512_cvtepi32_ps( i64 );
-    // Apply the scale, and accumulate
-    return _mm512_fmadd_ps( d, p, acc );
+    // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
+    const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
+    return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
 }
 #endif

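Side note on the AVX512-VNNI path above: VPDPBUSDS needs its first multiplicand to be unsigned, so the patch rewrites `(bytes_0 - 8) * (bytes_1 - 8)` as `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`, and the four `+ 64` terms of each 32-bit lane are folded into the `4 * 64` accumulator seed. A minimal standalone C sketch (not part of the patch) that checks this per-pair identity:

#include <assert.h>
#include <stdio.h>

int main(void) {
    // (a - 8) * (b - 8) == a * (b - 8) + b * (-8) + 64 for all nibbles a, b in [0, 15].
    // Summed over the 4 byte pairs of one 32-bit lane, the `+ 64` terms become the
    // `4 * 64` value used to initialize `dot_init` in the VNNI path.
    for (int a = 0; a < 16; ++a) {
        for (int b = 0; b < 16; ++b) {
            const int reference = (a - 8) * (b - 8);           // what Q4_0 actually needs
            const int rewritten = a * (b - 8) + b * (-8) + 64; // what VPDPBUSDS can compute
            assert(reference == rewritten);
        }
    }
    printf("identity holds for all 256 nibble pairs\n");
    return 0;
}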
@@ -1972,25 +2126,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     __m512 acc0 = _mm512_setzero_ps();
     __m512 acc1 = _mm512_setzero_ps();

-    const int superblock_size = 8;
+    const int superblock_size = 16;
+
     const int superblock_count = nb / superblock_size;

     for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
         int i = superblock_ix * superblock_size;

-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i + 0 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i + 1 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i + 2 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i + 3 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i + 4 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i + 5 );
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i + 6 );
-        acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i + 7 );
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i + 0 );
+        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i + 2 );
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i + 4 );
+        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i + 6 );
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i + 8 );
+        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i + 10 );
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i + 12 );
+        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i + 14 );
     }

     // Remainders
-    for (int i = superblock_count * superblock_size; i < nb; ++i) {
-        acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
+    for (int i = superblock_count * superblock_size; i < nb; i += 2) {
+        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
     }

     // Horizontal sum of all lanes of the accumulator
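For reference, here is a scalar sketch of what a single `dot_q4_0_twoblocks_avx512` call contributes to the accumulator, assuming the Q4_0 layout described in the comments above (a 4-byte float scale followed by QK/2 = 16 bytes holding 32 nibbles per block). The helper name and the nibble unpacking order are illustrative, not taken from the patch:

#include <stdint.h>

#define QK 32

typedef struct {
    float   d;          // scale
    uint8_t qs[QK / 2]; // 32 nibbles packed two per byte
} block_q4_0;

// Scalar equivalent of one dot_q4_0_twoblocks_avx512() call: processes blocks i and i + 1
// of both operands and returns their contribution to the dot product.
static float dot_q4_0_twoblocks_scalar(const block_q4_0 * x, const block_q4_0 * y, int i) {
    float sum = 0.0f;
    for (int b = i; b < i + 2; ++b) {
        int isum = 0;
        for (int k = 0; k < QK / 2; ++k) {
            // Unpack two nibbles per byte and recenter them into [-8, +7]. Both operands
            // are unpacked the same way, so the pairing stays consistent regardless of
            // whether the low or the high nibble is stored first.
            const int x0 = (x[b].qs[k] & 0x0f) - 8, x1 = (x[b].qs[k] >> 4) - 8;
            const int y0 = (y[b].qs[k] & 0x0f) - 8, y1 = (y[b].qs[k] >> 4) - 8;
            isum += x0 * y0 + x1 * y1;
        }
        sum += x[b].d * y[b].d * (float) isum;
    }
    return sum;
}

One consequence visible in the hunk above: since each call consumes two blocks, the remainder loop steps by 2, and the superblock size of 16 blocks corresponds to eight calls split across the two accumulators.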
@@ -10907,6 +11062,22 @@ int ggml_cpu_has_avx512(void) {
 #endif
 }

+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
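A minimal usage sketch for the two new capability reporters, assuming they are declared alongside the existing ggml_cpu_has_* functions (the header change is not shown in these hunks):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    // Like the other ggml_cpu_has_*() helpers, these report whether ggml was *compiled*
    // with AVX512-VBMI / AVX512-VNNI support, not what the CPU offers at runtime.
    printf("AVX512_VBMI = %d\n", ggml_cpu_has_avx512_vbmi());
    printf("AVX512_VNNI = %d\n", ggml_cpu_has_avx512_vnni());
    return 0;
}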