@@ -467,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
// quantization
//

- // AVX routines provided by GH user Const-me
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
+ #if __AVX__ || __AVX2__ || __AVX512F__
+ // Unpack 16 4-bit fields into 16 bytes
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+ {
+     // Load 8 bytes from memory
+     __m128i tmp = _mm_loadu_si64( (const __m128i *)rsi );
+
+     // Expand bytes into uint16_t values
+     __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+     // Unpack values into individual bytes
+     const __m128i lowMask = _mm_set1_epi8( 0xF );
+     __m128i high = _mm_andnot_si128( lowMask, bytes );
+     __m128i low  = _mm_and_si128( lowMask, bytes );
+     high  = _mm_slli_epi16( high, 4 );
+     bytes = _mm_or_si128( low, high );
+     return bytes;
+ }
+
#if __AVX2__ || __AVX512F__
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
- static inline __m256i bytesFromNibbles(const uint8_t * rsi)
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
    // Load 16 bytes from memory
    __m128i tmp = _mm_loadu_si128( (const __m128i *)rsi );
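For reference, a scalar equivalent of the new unpack helper (a sketch, not part of the patch; bytes_from_nibbles_16_ref, src, and dst are hypothetical names): each source byte holds two 4-bit fields, low nibble first, so 8 input bytes become 16 output bytes in [ 0 .. 15 ].

    // Hypothetical scalar reference for bytes_from_nibbles_16 (not in the patch):
    // unpack 8 source bytes into 16 values in [0..15], low nibble first.
    static inline void bytes_from_nibbles_16_ref(const uint8_t * src, uint8_t * dst) {
        for (int k = 0; k < 8; ++k) {
            dst[2*k + 0] = src[k] & 0x0F; // low nibble
            dst[2*k + 1] = src[k] >> 4;   // high nibble
        }
    }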
@@ -503,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
    return _mm_packus_epi16( r0, r1 );
}
- #elif __AVX__
- static inline __m128i bytesFromNibbles(const uint8_t * rsi)
- {
-     // Load 8 bytes from memory
-     __m128i tmp = _mm_loadu_si64( (const __m128i *)rsi );
-
-     // Expand bytes into uint16_t values
-     __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-     // Unpack values into individual bytes
-     const __m128i lowMask = _mm_set1_epi8( 0xF );
-     __m128i high = _mm_andnot_si128( lowMask, bytes );
-     __m128i low  = _mm_and_si128( lowMask, bytes );
-     high  = _mm_slli_epi16( high, 4 );
-     bytes = _mm_or_si128( low, high );
-     return bytes;
- }
-
+ #else
static inline __m128i packNibbles(__m128i bytes1, __m128i bytes2)
{
    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -537,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
    return _mm_packus_epi16( bytes1, bytes2 );
}
#endif
+ #endif // __AVX__ || __AVX2__ || __AVX512F__

#if __ARM_NEON
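packNibbles is the inverse of the unpack helpers, which is what lets the quantize/dequantize paths round-trip. A scalar sketch of the 16-value case (pack_nibbles_ref is a hypothetical name; assumes every input is already in [ 0 .. 15 ]):

    // Hypothetical scalar inverse of the unpack above: pack 16 values
    // in [0..15] back into 8 bytes, low nibble first.
    static inline void pack_nibbles_ref(const uint8_t * src, uint8_t * dst) {
        for (int k = 0; k < 8; ++k) {
            dst[k] = (uint8_t)((src[2*k + 0] & 0x0F) | (src[2*k + 1] << 4));
        }
    }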
@@ -1395,7 +1397,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k)
        for (int l = 0; l < QK4_0; l += 32) {
            // Load 32x4-bit integers into 32x8-bit integers
-             __m256i vx8 = bytesFromNibbles(pp + l/2);
+             __m256i vx8 = bytes_from_nibbles_32(pp + l/2);

            // Subtract 8 from the integers
            vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
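The subtract-by-8 matches the Q4_0 decoding rule: each unsigned 4-bit value q is recentered to [ -8 .. +7 ] and scaled, y = d * (q - 8). A scalar sketch of the same step, reusing the surrounding pp, d, and y (block offsets omitted):

    // Scalar view of this hunk, assuming y points at the block's output:
    for (int l = 0; l < QK4_0; l += 2) {
        const uint8_t b = pp[l/2];
        y[l + 0] = ((int8_t)(b & 0x0F) - 8)*d; // low nibble
        y[l + 1] = ((int8_t)(b >>   4) - 8)*d; // high nibble
    }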
@@ -1513,7 +1515,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k)
        for (int l = 0; l < QK4_1; l += 32) {
            // Load 32x4-bit integers into 32x8-bit integers
-             __m256i vx8 = bytesFromNibbles(pp + l/2);
+             __m256i vx8 = bytes_from_nibbles_32(pp + l/2);

            // Convert to 16-bit int
            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
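Q4_1 differs from Q4_0 only in the affine term: the 4-bit values stay unsigned and a per-block minimum is added, y = d * q + m. A scalar sketch of the same step (reusing the surrounding pp, d, m, and y; block offsets omitted):

    // Scalar view of this hunk, assuming y points at the block's output:
    for (int l = 0; l < QK4_1; l += 2) {
        const uint8_t b = pp[l/2];
        y[l + 0] = (b & 0x0F)*d + m; // low nibble
        y[l + 1] = (b >>   4)*d + m; // high nibble
    }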
@@ -2356,7 +2358,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy)
        /* Compute combined scale for the block */
        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );

-         __m256i bx = bytesFromNibbles(x[i].qs);
+         __m256i bx = bytes_from_nibbles_32(x[i].qs);

        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
        const __m256i off = _mm256_set1_epi8( 8 );
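A single scale multiply per block is valid because both operands are uniformly scaled: sum_l (d_x*qx[l]) * (d_y*qy[l]) = d_x*d_y * sum_l qx[l]*qy[l]. A scalar sketch of one block's contribution (qx and qy are hypothetical decoded arrays):

    // Hypothetical scalar view of one Q4_0 x Q8_0 block:
    int32_t isum = 0;
    for (int l = 0; l < QK8_0; ++l) {
        isum += qx[l] * qy[l];       // integer dot product, qx[l] in [-8..7]
    }
    sumf += x[i].d * y[i].d * isum;  // one combined scale per block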
@@ -2402,7 +2404,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy)
        __m128i i32[2];
        for (int j = 0; j < 2; ++j) {
            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-             __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+             __m128i bx = bytes_from_nibbles_16( x[i].qs + 8*j );
            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));

            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2567,7 +2569,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy)
        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );

        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-         const __m256i bx = bytesFromNibbles( x[i].qs );
+         const __m256i bx = bytes_from_nibbles_32( x[i].qs );
        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );

        // Get absolute values of x vectors
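The d1m0 = d1v * m0v product is the cross term of Q4_1's affine encoding: expanding (d0*qx[l] + m0) * (d1*qy[l]) over a block gives d0*d1 * sum qx*qy + m0*d1 * sum qy. A scalar sketch (qx and qy are hypothetical decoded arrays):

    // Hypothetical scalar view of one Q4_1 x Q8_0 block:
    int32_t dot = 0, ysum = 0;
    for (int l = 0; l < QK8_0; ++l) {
        dot  += qx[l] * qy[l]; // qx[l] in [0..15]
        ysum += qy[l];
    }
    sumf += x[i].d * y[i].d * dot + x[i].m * y[i].d * ysum;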
@@ -2721,6 +2723,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy)
    }

    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+ #elif defined(__AVX2__)
+     // Initialize accumulator with zeros
+     __m256 acc = _mm256_setzero_ps();
+
+     // Main loop
+     for (int i = 0; i < nb; i++) {
+         /* Compute combined scale for the block */
+         const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+         const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+         const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
+
+         __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+         __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+         __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+         // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+         const __m256i off = _mm256_set1_epi8(8);
+         bx = _mm256_sub_epi8(bx, off);
+
+         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+         // Get absolute values of x vectors
+         const __m256i ax = _mm256_sign_epi8(bx, bx);
+         // Sign the values of the y vectors
+         const __m256i sy = _mm256_sign_epi8(by, bx);
+         // Perform multiplication and create 16-bit values
+         const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+         const __m256i ones = _mm256_set1_epi16(1);
+         __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+         /* Convert vector of 8 int32_t to 8 floats */
+         __m256 q = _mm256_cvtepi32_ps(xy_q);
+
+         /* Multiply q with scale and accumulate */
+         acc = _mm256_fmadd_ps(d, q, acc);
+     }
+
+     // Return horizontal sum of the acc vector
+     __m128 res = _mm256_extractf128_ps(acc, 1);
+     res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
+     res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+     res = _mm_add_ss(res, _mm_movehdup_ps(res));
+
+     sumf = _mm_cvtss_f32(res);
#else
    // scalar
    for (int i = 0; i < nb; i++) {
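The ax/sy pair in the new block is the standard workaround for _mm256_maddubs_epi16 taking an unsigned first operand: multiplying |x| by sign(x)*y reproduces x*y in every case, including x == 0. A per-lane scalar check (a sketch, not part of the patch):

    // |x| * (sign(x) * y) == x * y, so maddubs over (ax, sy) accumulates
    // the same products as a signed multiply would.
    static inline int mul_via_sign_trick(int8_t x, int8_t y) {
        const uint8_t ax = (uint8_t)(x < 0 ? -x : x);       // _mm256_sign_epi8(bx, bx)
        const int     sy = (x < 0) ? -y : (x == 0 ? 0 : y); // _mm256_sign_epi8(by, bx)
        return ax * sy;                                     // one lane of maddubs
    }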