@@ -580,7 +580,63 @@ static inline __m128i packNibbles( __m256i bytes )
     return _mm_packus_epi16( r0, r1 );
 #endif
 }
-#else
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return _mm256_set_m128i(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return _mm256_set_m128i(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
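Two of the new helpers deserve a note. bytes_from_bits_32 broadcasts the 32-bit word to all byte positions, shuffles so that byte i receives its source byte, ORs with 0x7fbfdfeff7fbfdfe (each byte of that constant is the complement of its selector bit), and compares against all-ones, so a byte becomes 0xFF exactly when its bit was set. A minimal scalar model of the contract, for illustration only and not part of the patch:

#include <stdint.h>
#include <string.h>

// Scalar model of bytes_from_bits_32: bit i of the 32-bit word
// becomes output byte i, either 0x00 or 0xFF.
static inline void bytes_from_bits_32_ref(const uint8_t * x, uint8_t out[32]) {
    uint32_t x32;
    memcpy(&x32, x, sizeof(uint32_t));
    for (int i = 0; i < 32; ++i) {
        out[i] = ((x32 >> i) & 1) ? 0xFF : 0x00;
    }
}

mul_sum_i8_pairs_float works around _mm_maddubs_epi16 taking an unsigned first operand: it feeds in |x| and a copy of y carrying x's sign, which is valid because x*y == |x| * sign(x)*y (both sides are zero when x is zero). Per 16-bit lane, and ignoring the int16 saturation that these operand ranges cannot reach, the step computes:

// Scalar model of one 16-bit lane of the maddubs step (illustration only).
static inline int mul_sum_i8_pair_ref(int8_t x0, int8_t y0, int8_t x1, int8_t y1) {
    int ax0 = x0 < 0 ? -x0 : x0;             // _mm_sign_epi8(x, x)
    int sy0 = x0 < 0 ? -y0 : (x0 ? y0 : 0);  // _mm_sign_epi8(y, x)
    int ax1 = x1 < 0 ? -x1 : x1;
    int sy1 = x1 < 0 ? -y1 : (x1 ? y1 : 0);
    return ax0 * sy0 + ax1 * sy1;            // == x0*y0 + x1*y1
}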
@@ -2355,7 +2411,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
     }

     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

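With the 128-bit helpers above defined behind the same names, the body of the existing AVX2 branch compiles unchanged on plain AVX, so this hunk only widens the preprocessor guard; the one AVX2-era instruction left in the loop, the fused multiply-add, is special-cased in the next hunk.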
@@ -2381,7 +2437,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const __m256 xy = mul_sum_i8_pairs_float(bx, by);

         // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
         acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
     }

     *s = hsum_float_8(acc) + summs;
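_mm256_fmadd_ps belongs to the FMA extension, which this codebase enables together with AVX2, so the plain-AVX fallback spells it out as a multiply followed by an add. The result can differ in the last bit: a fused multiply-add rounds once, while the two-instruction form rounds twice.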
@@ -2592,6 +2652,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         acc = _mm256_fmadd_ps(d, q, acc);
     }

+    *s = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8((char)0xF0);
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_andnot_si128(bxhil, mask);
+        bxhih = _mm_andnot_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = _mm256_set_m128i(bxh, bxl);
+
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    }
+
     *s = hsum_float_8(acc);
 #else
     // scalar
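The q5_0 path merges the fifth bit so that the usual -16 recentering comes for free: bytes_from_bits_32 yields 0xFF where the high bit is set, the andnot against 0xF0 turns that into 0x00 (bit set) or 0xF0 (bit clear), and ORing with the low nibble produces exactly the signed value (nibble | hi<<4) - 16 as an int8_t. A per-byte scalar sketch, illustration only:

#include <stdint.h>

// hi is the fifth bit of the weight, nibble its low four bits.
static inline int8_t q5_0_merge_ref(uint8_t nibble, int hi) {
    uint8_t spread = hi ? 0xFF : 0x00;                      // bytes_from_bits_32
    uint8_t merged = (uint8_t)((~spread & 0xF0) | nibble);  // _mm_andnot + _mm_or
    return (int8_t)merged;  // equals (nibble | hi << 4) - 16
}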
@@ -2820,6 +2911,40 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }

+    *s = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8(0x10);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_and_si128(bxhil, mask);
+        bxhih = _mm_and_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx);
+        __m128i bxh = _mm256_extractf128_si256(bx, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx = _mm256_set_m128i(bxh, bxl);
+
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+    }
+
     *s = hsum_float_8(acc) + summs;
 #else
     // scalar
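q5_1 weights are unsigned with a separate offset m, so here the mask is 0x10 and a plain and places the fifth bit at position 4, giving bytes in [0, 31]; the offset term is folded in scalar form through summs (x[i].m times the precomputed row sum y[i].s). Again a per-byte sketch, illustration only:

#include <stdint.h>

// Unlike q5_0 there is no -16 recentering for q5_1.
static inline uint8_t q5_1_merge_ref(uint8_t nibble, int hi) {
    uint8_t spread = hi ? 0xFF : 0x00;             // bytes_from_bits_32
    return (uint8_t)((spread & 0x10) | nibble);    // _mm_and + _mm_or
}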
@@ -2910,7 +3035,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     }

     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

@@ -2924,7 +3049,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
         const __m256 q = mul_sum_i8_pairs_float(bx, by);

         // Multiply q with scale and accumulate
+#if defined(__AVX2__)
         acc = _mm256_fmadd_ps( d, q, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
     }

     *s = hsum_float_8(acc);
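q8_0 blocks already store signed bytes, so no unpacking or high-bit merging is needed: widening the guard plus the same FMA special-case, exactly as in the q4_1 hunk above, completes the port.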