@@ -463,12 +463,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-// AVX routines provided by GH user Const-me
-// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadu_si64((const __m128i *)rsi);
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16(tmp);
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    __m128i high = _mm_andnot_si128(lowMask, bytes);
+    __m128i low  = _mm_and_si128(lowMask, bytes);
+    high = _mm_slli_epi16(high, 4);
+    bytes = _mm_or_si128(low, high);
+    return bytes;
+}
+
 #if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytesFromNibbles(const uint8_t * rsi)
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     // Load 16 bytes from memory
     __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
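
For reference, here is a minimal scalar sketch of what the new bytes_from_nibbles_16 computes (illustrative, not part of the patch). Note that the low and high nibbles of each source byte land in adjacent output positions, matching the interleaved nibble layout of the q4 blocks:

#include <stdint.h>

// Scalar model: expand 8 packed bytes into 16 nibble values in [0..15].
// Output byte 2*j gets the low nibble of src[j], byte 2*j+1 the high nibble,
// exactly like the SIMD sequence above (cvtepu8 + mask + shift + or).
static void bytes_from_nibbles_16_scalar(const uint8_t src[8], uint8_t dst[16])
{
    for (int j = 0; j < 8; ++j) {
        dst[2*j + 0] = src[j] & 0x0F;  // low nibble
        dst[2*j + 1] = src[j] >> 4;    // high nibble
    }
}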
@@ -499,24 +517,7 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
 }
-#elif __AVX__
-static inline __m128i bytesFromNibbles(const uint8_t * rsi)
-{
-    // Load 8 bytes from memory
-    __m128i tmp = _mm_loadu_si64((const __m128i *)rsi);
-
-    // Expand bytes into uint16_t values
-    __m128i bytes = _mm_cvtepu8_epi16(tmp);
-
-    // Unpack values into individual bytes
-    const __m128i lowMask = _mm_set1_epi8(0xF);
-    __m128i high = _mm_andnot_si128(lowMask, bytes);
-    __m128i low  = _mm_and_si128(lowMask, bytes);
-    high = _mm_slli_epi16(high, 4);
-    bytes = _mm_or_si128(low, high);
-    return bytes;
-}
-
+#else
 static inline __m128i packNibbles(__m128i bytes1, __m128i bytes2)
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -533,6 +534,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2 );
 }
 #endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
 
 #if __ARM_NEON
 
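After this restructuring, the x86 helper guards nest as sketched below (bodies elided; this is a reading aid, not code from the patch). bytes_from_nibbles_16 is now visible to every AVX-class build, while the 256-bit helpers and the SSE packNibbles fallback keep their AVX2-vs-AVX split:

#if __AVX__ || __AVX2__ || __AVX512F__
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)  { /* ... */ }

#if __AVX2__ || __AVX512F__
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)  { /* ... */ }
static inline __m128i packNibbles(__m256i bytes)                  { /* ... */ }
#else
static inline __m128i packNibbles(__m128i bytes1, __m128i bytes2) { /* ... */ }
#endif
#endif // __AVX__ || __AVX2__ || __AVX512F__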
@@ -1309,7 +1311,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         for (int l = 0; l < QK4_0; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp + l/2);
+            __m256i vx8 = bytes_from_nibbles_32(pp + l/2);
 
             // Subtract 8 from the integers
             vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
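
The renamed call keeps the AVX2 path equivalent to q4_0's scalar dequantization, which reconstructs each value as (nibble - 8) * d. A minimal scalar sketch, assuming d, pp, y, and i are as in the surrounding function:

        for (int l = 0; l < QK4_0; l += 2) {
            const uint8_t vi = pp[l/2];
            y[i*QK4_0 + l + 0] = ((vi & 0x0F) - 8)*d;  // low nibble, recentered
            y[i*QK4_0 + l + 1] = ((vi >>   4) - 8)*d;  // high nibble, recentered
        }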
@@ -1427,7 +1429,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         for (int l = 0; l < QK4_1; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp + l/2);
+            __m256i vx8 = bytes_from_nibbles_32(pp + l/2);
 
             // Convert to 16-bit int
             const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
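
For q4_1 the same unpacking feeds an affine reconstruction, v*d + m, where m is the block minimum rather than an implicit offset of 8. A scalar sketch under the same assumptions as above (d, m, pp, y, and i from the surrounding function):

        for (int l = 0; l < QK4_1; l += 2) {
            const uint8_t vi = pp[l/2];
            y[i*QK4_1 + l + 0] = (vi & 0x0F)*d + m;  // low nibble
            y[i*QK4_1 + l + 1] = (vi >>   4)*d + m;  // high nibble
        }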
@@ -2270,7 +2272,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         /* Compute combined scale for the block */
         const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
 
-        __m256i bx = bytesFromNibbles(x[i].qs);
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
 
         // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
         const __m256i off = _mm256_set1_epi8( 8 );
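
The combined scale exists because each block's contribution to the dot product factors as d_x * d_y times an integer dot product of the quants. A scalar model of one block's contribution (illustrative only; the name and signature are not from the patch):

static float block_contribution(float dx, float dy,
                                const int8_t qx[32],  // q4_0 quants, already in [-8, 7]
                                const int8_t qy[32])  // q8_0 quants
{
    int isum = 0;
    for (int l = 0; l < 32; ++l) {
        isum += qx[l] * qy[l];    // integer dot product of the block
    }
    return dx * dy * (float)isum; // scale once per block, as the __m256 d does
}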
@@ -2316,7 +2318,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         __m128i i32[2];
         for (int j = 0; j < 2; ++j) {
             // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i bx = bytes_from_nibbles_16( x[i].qs + 8*j );
             __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
 
             // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2481,7 +2483,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        const __m256i bx = bytesFromNibbles( x[i].qs );
+        const __m256i bx = bytes_from_nibbles_32( x[i].qs );
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
         // Get absolute values of x vectors
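
The absolute-value/sign pairing here (and in the new q4_2 path below) exists because _mm256_maddubs_epi16 requires its first operand to be unsigned. Multiplying |bx| by sign(bx)*by yields the same products as bx*by lane by lane. A scalar model of one resulting i16 lane (illustrative only):

static int16_t maddubs_lane(int8_t bx0, int8_t by0, int8_t bx1, int8_t by1)
{
    // _mm256_sign_epi8(bx, bx): absolute value (and 0 stays 0)
    const uint8_t ax0 = (uint8_t)(bx0 < 0 ? -bx0 : bx0);
    const uint8_t ax1 = (uint8_t)(bx1 < 0 ? -bx1 : bx1);
    // _mm256_sign_epi8(by, bx): negate by where bx < 0, zero it where bx == 0
    const int8_t sy0 = (int8_t)(bx0 < 0 ? -by0 : bx0 == 0 ? 0 : by0);
    const int8_t sy1 = (int8_t)(bx1 < 0 ? -by1 : bx1 == 0 ? 0 : by1);
    // _mm256_maddubs_epi16: adjacent u8*s8 products summed into one i16
    return (int16_t)(ax0*sy0 + ax1*sy1);
}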
@@ -2635,6 +2637,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
+
+        __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8(8);
+        bx = _mm256_sub_epi8(bx, off);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert vector of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps(xy_q);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps(acc, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+
+    sumf = _mm_cvtss_f32(res);
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
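
The four-instruction reduction at the end of the new AVX2 block folds 8 accumulated floats down to 1 in log2 steps. A scalar model of the lane arithmetic (illustrative only):

static float hsum_avx2_model(const float v[8])
{
    float t4[4], t2[2];
    for (int k = 0; k < 4; ++k) t4[k] = v[k] + v[k + 4];   // extractf128 + add_ps
    for (int k = 0; k < 2; ++k) t2[k] = t4[k] + t4[k + 2]; // movehl_ps + add_ps
    return t2[0] + t2[1];                                   // movehdup_ps + add_ss
}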