@@ -463,12 +463,10 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
463
463
// quantization
464
464
//
465
465
466
- // AVX routines provided by GH user Const-me
467
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
468
466
#if __AVX2__ || __AVX512F__
469
467
// Unpack 32 4-bit fields into 32 bytes
470
468
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
471
- static inline __m256i bytesFromNibbles ( const uint8_t * rsi )
469
+ static inline __m256i bytes_from_nibbles_32 ( const uint8_t * rsi )
472
470
{
473
471
// Load 16 bytes from memory
474
472
__m128i tmp = _mm_loadu_si128 ( ( const __m128i * )rsi );
@@ -485,7 +483,7 @@ static inline __m256i bytesFromNibbles( const uint8_t* rsi )
485
483
return bytes ;
486
484
}
487
485
488
- static inline __m128i packNibbles ( __m256i bytes )
486
+ static inline __m128i pack_nibbles_32 ( __m256i bytes )
489
487
{
490
488
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
491
489
const __m256i lowByte = _mm256_set1_epi16 ( 0xFF );
@@ -499,8 +497,11 @@ static inline __m128i packNibbles( __m256i bytes )
499
497
__m128i r1 = _mm256_extracti128_si256 ( bytes , 1 );
500
498
return _mm_packus_epi16 ( r0 , r1 );
501
499
}
502
- #elif __AVX__
503
- static inline __m128i bytesFromNibbles ( const uint8_t * rsi )
500
+ #endif
501
+ #if __AVX__ || __AVX2__ || __AVX512F__
502
+ // Unpack 16 4-bit fields into 16 bytes
503
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
504
+ static inline __m128i bytes_from_nibbles_16 (const uint8_t * rsi )
504
505
{
505
506
// Load 8 bytes from memory
506
507
__m128i tmp = _mm_loadu_si64 ( ( const __m128i * )rsi );
@@ -517,7 +518,7 @@ static inline __m128i bytesFromNibbles( const uint8_t* rsi )
517
518
return bytes ;
518
519
}
519
520
520
- static inline __m128i packNibbles ( __m128i bytes1 , __m128i bytes2 )
521
+ static inline __m128i pack_nibbles_16 ( __m128i bytes1 , __m128i bytes2 )
521
522
{
522
523
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
523
524
const __m128i lowByte = _mm_set1_epi16 ( 0xFF );
@@ -820,7 +821,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
820
821
i0 = _mm256_add_epi8 ( i0 , off );
821
822
822
823
// Compress the vector into 4 bit/value, and store
823
- __m128i res = packNibbles ( i0 );
824
+ __m128i res = pack_nibbles_32 ( i0 );
824
825
_mm_storeu_si128 ( ( __m128i * )y [i ].qs , res );
825
826
}
826
827
#elif defined(__AVX__ )
@@ -894,7 +895,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
894
895
ni4 = _mm_add_epi8 ( ni4 , off );
895
896
896
897
// Compress the vector into 4 bit/value, and store
897
- __m128i res = packNibbles ( ni0 , ni4 );
898
+ __m128i res = pack_nibbles_16 ( ni0 , ni4 );
898
899
_mm_storeu_si128 ( ( __m128i * )y [i ].qs , res );
899
900
}
900
901
#elif defined(__wasm_simd128__ )
@@ -1055,7 +1056,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
1055
1056
i0 = _mm256_permutevar8x32_epi32 ( i0 , perm );
1056
1057
1057
1058
// Compress the vector into 4 bit/value, and store
1058
- __m128i res = packNibbles ( i0 );
1059
+ __m128i res = pack_nibbles_32 ( i0 );
1059
1060
_mm_storeu_si128 ( ( __m128i * )y [i ].qs , res );
1060
1061
}
1061
1062
#elif __ARM_NEON
@@ -1309,7 +1310,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
1309
1310
1310
1311
for (int l = 0 ; l < QK4_0 ; l += 32 ) {
1311
1312
// Load 32x4-bit integers into 32x8-bit integers
1312
- __m256i vx8 = bytesFromNibbles (pp + l /2 );
1313
+ __m256i vx8 = bytes_from_nibbles_32 (pp + l /2 );
1313
1314
1314
1315
// Subtract 8 from the integers
1315
1316
vx8 = _mm256_sub_epi8 (vx8 , _mm256_set1_epi8 (8 ));
@@ -1427,7 +1428,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1427
1428
1428
1429
for (int l = 0 ; l < QK4_1 ; l += 32 ) {
1429
1430
// Load 32x4-bit integers into 32x8-bit integers
1430
- __m256i vx8 = bytesFromNibbles (pp + l /2 );
1431
+ __m256i vx8 = bytes_from_nibbles_32 (pp + l /2 );
1431
1432
1432
1433
// Convert to 16-bit int
1433
1434
const __m256i vx16_lo = _mm256_cvtepi8_epi16 (_mm256_extracti128_si256 (vx8 , 0 ));
@@ -2270,7 +2271,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2270
2271
/* Compute combined scale for the block */
2271
2272
const __m256 d = _mm256_mul_ps ( _mm256_broadcast_ss ( & x [i ].d ), _mm256_broadcast_ss ( & y [i ].d ) );
2272
2273
2273
- __m256i bx = bytesFromNibbles (x [i ].qs );
2274
+ __m256i bx = bytes_from_nibbles_32 (x [i ].qs );
2274
2275
2275
2276
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2276
2277
const __m256i off = _mm256_set1_epi8 ( 8 );
@@ -2316,7 +2317,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2316
2317
__m128i i32 [2 ];
2317
2318
for (int j = 0 ; j < 2 ; ++ j ) {
2318
2319
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2319
- __m128i bx = bytesFromNibbles ( x [i ].qs + 8 * j );
2320
+ __m128i bx = bytes_from_nibbles_16 ( x [i ].qs + 8 * j );
2320
2321
__m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + 16 * j ));
2321
2322
2322
2323
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2481,7 +2482,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2481
2482
const __m256 d1m0 = _mm256_mul_ps ( d1v , m0v );
2482
2483
2483
2484
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2484
- const __m256i bx = bytesFromNibbles ( x [i ].qs );
2485
+ const __m256i bx = bytes_from_nibbles_32 ( x [i ].qs );
2485
2486
const __m256i by = _mm256_loadu_si256 ( (const __m256i * )y [i ].qs );
2486
2487
2487
2488
// Get absolute values of x vectors
@@ -2635,6 +2636,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
2635
2636
}
2636
2637
2637
2638
sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2639
+ #elif defined(__AVX2__ )
2640
+ // Initialize accumulator with zeros
2641
+ __m256 acc = _mm256_setzero_ps ();
2642
+
2643
+ // Main loop
2644
+ for (int i = 0 ; i < nb ; i ++ ) {
2645
+ /* Compute combined scale for the block */
2646
+ const __m128 d0 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 0 ].d ));
2647
+ const __m128 d1 = _mm_set1_ps (GGML_FP16_TO_FP32 (x [2 * i + 1 ].d ));
2648
+ const __m256 d = _mm256_mul_ps (_mm256_set_m128 (d1 , d0 ), _mm256_broadcast_ss (& y [i ].d ));
2649
+
2650
+ __m128i bx0 = bytes_from_nibbles_16 (x [2 * i + 0 ].qs );
2651
+ __m128i bx1 = bytes_from_nibbles_16 (x [2 * i + 1 ].qs );
2652
+ __m256i bx = _mm256_set_m128i (bx1 , bx0 );
2653
+
2654
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2655
+ const __m256i off = _mm256_set1_epi8 (8 );
2656
+ bx = _mm256_sub_epi8 (bx , off );
2657
+
2658
+ __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2659
+
2660
+ // Get absolute values of x vectors
2661
+ const __m256i ax = _mm256_sign_epi8 (bx , bx );
2662
+ // Sign the values of the y vectors
2663
+ const __m256i sy = _mm256_sign_epi8 (by , bx );
2664
+ // Perform multiplication and create 16-bit values
2665
+ const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2666
+
2667
+ const __m256i ones = _mm256_set1_epi16 (1 );
2668
+ __m256i xy_q = _mm256_madd_epi16 (ones , dot );
2669
+
2670
+ /* Convert to vectore of 8 int32_t to 8 floats */
2671
+ __m256 q = _mm256_cvtepi32_ps (xy_q );
2672
+
2673
+ /* Multiply q with scale and accumulate */
2674
+ acc = _mm256_fmadd_ps (d , q , acc );
2675
+ }
2676
+
2677
+ // Return horizontal sum of the acc vector
2678
+ __m128 res = _mm256_extractf128_ps (acc , 1 );
2679
+ res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2680
+ res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2681
+ res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2682
+
2683
+ sumf = _mm_cvtss_f32 (res );
2638
2684
#else
2639
2685
// scalar
2640
2686
for (int i = 0 ; i < nb ; i ++ ) {
0 commit comments