@@ -1506,17 +1506,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
1506
1506
1507
1507
// reference implementation for deterministic creation of model files
1508
1508
static void quantize_row_q8_0c_reference (const float * restrict x , void * restrict y , int k ) {
1509
- assert (k % QK8_0 == 0 );
1510
- const int nb = k / QK8_0 ;
1509
+ assert (k % QK8_0C == 0 );
1510
+ const int nb = k / QK8_0C ;
1511
1511
1512
1512
uint8_t * restrict qs = y ;
1513
1513
float * restrict ds = (float * ) ((uint8_t * ) y + QK8_0C * nb );
1514
1514
1515
1515
for (int i = 0 ; i < nb ; i ++ ) {
1516
1516
float amax = 0.0f ; // absolute max
1517
1517
1518
- for (int l = 0 ; l < QK8_0 ; l ++ ) {
1519
- const float v = x [i * QK8_0 + l ];
1518
+ for (int l = 0 ; l < QK8_0C ; l ++ ) {
1519
+ const float v = x [i * QK8_0C + l ];
1520
1520
amax = MAX (amax , fabsf (v ));
1521
1521
}
1522
1522
@@ -1525,17 +1525,46 @@ static void quantize_row_q8_0c_reference(const float * restrict x, void * restri
1525
1525
1526
1526
ds [i ] = d ;
1527
1527
1528
- for (int l = 0 ; l < QK8_0 ; ++ l ) {
1529
- const float v = x [i * QK8_0 + l ]* id ;
1530
- qs [i * QK8_0 + l ] = roundf (v );
1528
+ for (int l = 0 ; l < QK8_0C ; ++ l ) {
1529
+ const float v = x [i * QK8_0C + l ]* id ;
1530
+ qs [i * QK8_0C + l ] = roundf (v );
1531
1531
}
1532
1532
}
1533
1533
}
1534
1534
1535
1535
static void quantize_row_q8_0c (const float * restrict x , void * restrict vy , int k ) {
1536
- assert (k % QK8_0 == 0 );
1536
+ assert (k % QK8_0C == 0 );
1537
+ const int nb = k / QK8_0C ;
1538
+
1539
+ int8_t * restrict qs = vy ;
1540
+ float * restrict ds = (float * ) ((uint8_t * ) vy + nb * QK8_0C );
1541
+
1542
+ #if __AVX512F__
1543
+ for (int i = 0 ; i < nb ; i ++ ) {
1544
+ const __m512 x0 = _mm512_loadu_ps ( x + i * QK8_0C );
1545
+ const __m512 x1 = _mm512_loadu_ps ( x + i * QK8_0C + QK8_0C /2 );
1546
+
1547
+ // Find absolute max
1548
+ const __m512 x0abs = _mm512_abs_ps (x0 );
1549
+ const __m512 x1abs = _mm512_abs_ps (x1 );
1550
+ const float amax = _mm512_reduce_max_ps (_mm512_max_ps (x0abs , x1abs ));
1551
+
1552
+ const float d = amax / ((1 << 7 ) - 1 );
1553
+ const float id = d ? 1.0f /d : 0.0f ;
1554
+
1555
+ ds [i ] = d ;
1537
1556
1557
+ const __m512 mul = _mm512_set1_ps ( id );
1558
+ const __m512i x0q = _mm512_cvt_roundps_epi32 (_mm512_mul_ps (x0 , mul ), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ));
1559
+ const __m512i x1q = _mm512_cvt_roundps_epi32 (_mm512_mul_ps (x1 , mul ), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ));
1560
+
1561
+ _mm512_mask_cvtepi32_storeu_epi8 (qs + i * QK8_0C , 0xffff , x0q );
1562
+ _mm512_mask_cvtepi32_storeu_epi8 (qs + i * QK8_0C + QK8_0C /2 , 0xffff , x1q );
1563
+ }
1564
+ #else
1565
+ // scalar
1538
1566
quantize_row_q8_0c_reference (x , vy , k );
1567
+ #endif
1539
1568
}
1540
1569
1541
1570
static void dequantize_row_q4_0 (const void * restrict vx , float * restrict y , int k ) {
@@ -2478,6 +2507,73 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
2478
2507
* s = sumf ;
2479
2508
}
2480
2509
2510
+ #if __AVX512F__ && QK4_0 == 32
2511
+
2512
+ // Dot product of four blocks of q4_0c with four blocks of q8_0c
2513
+ static inline __m512 dot_q4_0c_fourblocks_avx512 (
2514
+ __m512 acc ,
2515
+ const uint8_t * restrict xqs ,
2516
+ const float * restrict xds ,
2517
+ const int8_t * restrict yqs ,
2518
+ const float * restrict yds
2519
+ ) {
2520
+ // load quantized bytes
2521
+ // TODO: change back to aligned loads
2522
+ const __m512i xqs0123 = _mm512_loadu_epi64 ( xqs );
2523
+ const __m512i low_nibble_mask = _mm512_set1_epi8 ( 0xf );
2524
+ const __m512i xqs01 = _mm512_and_si512 ( low_nibble_mask , xqs0123 );
2525
+ // TODO: try srlv/i?
2526
+ const __m512i xqs23 = _mm512_and_si512 ( low_nibble_mask , _mm512_srli_epi32 ( xqs0123 , 4 ) );
2527
+ const __m512i yqs01 = _mm512_loadu_epi64 ( yqs );
2528
+ const __m512i yqs23 = _mm512_loadu_epi64 ( yqs + 2 * QK8_0C );
2529
+
2530
+ // load scales
2531
+ const __m512i scale_mask0 = _mm512_set_epi32 (1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 );
2532
+ const __m512i scale_mask1 = _mm512_set_epi32 (3 , 3 , 3 , 3 , 3 , 3 , 3 , 3 , 2 , 2 , 2 , 2 , 2 , 2 , 2 , 2 );
2533
+ const __m128 xyds = _mm_mul_ps (_mm_load_ps (xds ), _mm_load_ps (yds ));
2534
+ const __m512 xyds0123 = _mm512_broadcast_f32x4 (xyds );
2535
+ const __m512 xyds01 = _mm512_permutevar_ps (xyds0123 , scale_mask0 );
2536
+ const __m512 xyds23 = _mm512_permutevar_ps (xyds0123 , scale_mask1 );
2537
+
2538
+ // take dot product of x and y bytes
2539
+ const __m512i plus_8 = _mm512_set1_epi8 ( 8 );
2540
+ #ifdef __AVX512VNNI__
2541
+ // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
2542
+ // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
2543
+ // from each nibble, so they can be negative. So, instead of `(xqs01 - 8) * yqs01`,
2544
+ // we compute `xqs01 * yqs01 - 8 * yqks`.
2545
+ const __m512i zero = _mm512_setzero_epi32 ();
2546
+ const __m512i yqs01_mul8 = _mm512_dpbusds_epi32 ( zero , plus_8 , yqs01 );
2547
+ const __m512i yqs23_mul8 = _mm512_dpbusds_epi32 ( zero , plus_8 , yqs23 );
2548
+ const __m512i xy01 = _mm512_dpbusds_epi32 ( zero , xqs01 , yqs01 );
2549
+ const __m512i xy23 = _mm512_dpbusds_epi32 ( zero , xqs23 , yqs23 );
2550
+ const __m512i res0_int = _mm512_sub_epi32 ( xy01 , yqs01_mul8 );
2551
+ const __m512i res1_int = _mm512_sub_epi32 ( xy23 , yqs23_mul8 );
2552
+ #else
2553
+ // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
2554
+ // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
2555
+ // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
2556
+ // ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
2557
+ const __m512i one = _mm512_set1_epi16 ( 1 );
2558
+ const __m512i prod_0 = _mm512_maddubs_epi16 ( xqs01 , yqs01 );
2559
+ const __m512i prod_1 = _mm512_maddubs_epi16 ( plus_8 , yqs01 );
2560
+ const __m512i prod_2 = _mm512_maddubs_epi16 ( xqs23 , yqs23 );
2561
+ const __m512i prod_3 = _mm512_maddubs_epi16 ( plus_8 , yqs23 );
2562
+ const __m512i diff0 = _mm512_sub_epi16 ( prod_0 , prod_1 );
2563
+ const __m512i diff1 = _mm512_sub_epi16 ( prod_2 , prod_3 );
2564
+ const __m512i res0_int = _mm512_madd_epi16 ( diff0 , one );
2565
+ const __m512i res1_int = _mm512_madd_epi16 ( diff1 , one );
2566
+ #endif
2567
+
2568
+ // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
2569
+ const __m512 res0_float = _mm512_cvtepi32_ps ( res0_int );
2570
+ const __m512 res1_float = _mm512_cvtepi32_ps ( res1_int );
2571
+
2572
+ return _mm512_fmadd_ps ( xyds23 , res1_float ,
2573
+ _mm512_fmadd_ps ( xyds01 , res0_float , acc ));
2574
+ }
2575
+ #endif
2576
+
2481
2577
inline static void ggml_vec_dot_f16 (const int n , float * restrict s , ggml_fp16_t * restrict x , ggml_fp16_t * restrict y ) {
2482
2578
ggml_float sumf = 0.0 ;
2483
2579
@@ -2721,6 +2817,15 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
2721
2817
2722
2818
float sumf = 0.0 ;
2723
2819
2820
+ #if __AVX512F__
2821
+ // Initialize accumulator with zeros
2822
+ __m512 acc = _mm512_setzero_ps ();
2823
+ for (int i = 0 ; i < nb ; i += 4 ) {
2824
+ acc = dot_q4_0c_fourblocks_avx512 (acc , xqs + i * QK4_0 /2 , xds + i , yqs + i * QK8_0 , yds + i );
2825
+ }
2826
+ // Horizontal sum of all lanes of the accumulator
2827
+ sumf = _mm512_reduce_add_ps ( acc );
2828
+ #else
2724
2829
// scalar
2725
2830
for (int i = 0 ; i < nb /2 ; i ++ ) {
2726
2831
const int dst0 = i + i /2 * 2 ; // 0, 1, 4, 5, 8, 9, ...
@@ -2731,23 +2836,25 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
2731
2836
const float dy0 = yds [dst0 ];
2732
2837
const float dy1 = yds [dst1 ];
2733
2838
2734
- int sumi0 = 0 ;
2735
- int sumi1 = 0 ;
2839
+ // NOTE: having these as plain int triggers a bug with AVX512 on GCC 12.2
2840
+ int64_t sumi0 = 0 ;
2841
+ int64_t sumi1 = 0 ;
2736
2842
2737
2843
for (int l = 0 ; l < QK4_0 ; l ++ ) {
2738
- const uint8_t v0 = xqs [i * QK4_0 + l ];
2844
+ const uint8_t v0 = xqs [i * QK4_0 + l ];
2739
2845
2740
- const int i0 = (int8_t ) (v0 & 0xf ) - 8 ;
2741
- const int i1 = (int8_t ) (v0 >> 4 ) - 8 ;
2846
+ const int i0 = (int ) (v0 & 0xf ) - 8 ;
2847
+ const int i1 = (int ) (v0 >> 4 ) - 8 ;
2742
2848
2743
- const int i2 = yqs [dst0 * QK4_0 + l ];
2744
- const int i3 = yqs [dst1 * QK4_0 + l ];
2849
+ const int i2 = yqs [dst0 * QK4_0 + l ];
2850
+ const int i3 = yqs [dst1 * QK4_0 + l ];
2745
2851
2746
- sumi0 += i0 * i2 ;
2747
- sumi1 += i1 * i3 ;
2852
+ sumi0 += i0 * i2 ;
2853
+ sumi1 += i1 * i3 ;
2748
2854
}
2749
2855
sumf += dx0 * dy0 * sumi0 + dx1 * dy1 * sumi1 ;
2750
2856
}
2857
+ #endif
2751
2858
2752
2859
* s = sumf ;
2753
2860
}
0 commit comments