Skip to content

Commit c8c2c52

Browse files
authored
AVX2 optimization for vec_dot_q4_2_q8_0 (#1068)
1 parent 02d6988 commit c8c2c52

File tree

1 file changed

+73
-26
lines changed

1 file changed

+73
-26
lines changed

ggml.c

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -467,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
467467
// quantization
468468
//
469469

470-
// AVX routines provided by GH user Const-me
471-
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
470+
#if __AVX__ || __AVX2__ || __AVX512F__
471+
// Unpack 16 4-bit fields into 16 bytes
// Reads 8 bytes starting at rsi (unaligned load) and splits every byte into its two nibbles:
// output byte 2*k holds the low nibble of rsi[k], output byte 2*k+1 holds the high nibble.
472+
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
473+
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
474+
{
475+
// Load 8 bytes from memory into the low half of the register (64-bit unaligned load)
476+
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
477+
478+
// Expand bytes into uint16_t values: each input byte is zero-extended into its own 16-bit lane,
// so the upper byte of every lane starts out as zero
479+
__m128i bytes = _mm_cvtepu8_epi16( tmp );
480+
481+
// Unpack values into individual bytes
// lowMask selects the low 4 bits of every byte
482+
const __m128i lowMask = _mm_set1_epi8( 0xF );
// high = bytes & ~lowMask: keeps bits 4..7 of each 16-bit lane (the high nibble of the source byte)
483+
__m128i high = _mm_andnot_si128( lowMask, bytes );
// low = bytes & lowMask: keeps bits 0..3 of each 16-bit lane (the low nibble of the source byte)
484+
__m128i low = _mm_and_si128( lowMask, bytes );
// Shift the high nibble from bits 4..7 to bits 8..11, i.e. into the low 4 bits of the lane's upper byte
485+
high = _mm_slli_epi16( high, 4 );
// Combine: lower byte of each lane = low nibble, upper byte of each lane = high nibble
486+
bytes = _mm_or_si128( low, high );
487+
return bytes;
488+
}
489+
472490
#if __AVX2__ || __AVX512F__
473491
// Unpack 32 4-bit fields into 32 bytes
474492
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
475-
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
493+
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
476494
{
477495
// Load 16 bytes from memory
478496
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -503,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
503521
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
504522
return _mm_packus_epi16( r0, r1 );
505523
}
506-
#elif __AVX__
507-
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
508-
{
509-
// Load 8 bytes from memory
510-
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
511-
512-
// Expand bytes into uint16_t values
513-
__m128i bytes = _mm_cvtepu8_epi16( tmp );
514-
515-
// Unpack values into individual bytes
516-
const __m128i lowMask = _mm_set1_epi8( 0xF );
517-
__m128i high = _mm_andnot_si128( lowMask, bytes );
518-
__m128i low = _mm_and_si128( lowMask, bytes );
519-
high = _mm_slli_epi16( high, 4 );
520-
bytes = _mm_or_si128( low, high );
521-
return bytes;
522-
}
523-
524+
#else
524525
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
525526
{
526527
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -537,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
537538
return _mm_packus_epi16( bytes1, bytes2);
538539
}
539540
#endif
541+
#endif // __AVX__ || __AVX2__ || __AVX512F__
540542

541543
#if __ARM_NEON
542544

@@ -1395,7 +1397,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
13951397

13961398
for (int l = 0; l < QK4_0; l += 32) {
13971399
// Load 32x4-bit integers into 32x8-bit integers
1398-
__m256i vx8 = bytesFromNibbles(pp+l/2);
1400+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
13991401

14001402
// Subtract 8 from the integers
14011403
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1513,7 +1515,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
15131515

15141516
for (int l = 0; l < QK4_1; l += 32) {
15151517
// Load 32x4-bit integers into 32x8-bit integers
1516-
__m256i vx8 = bytesFromNibbles(pp+l/2);
1518+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
15171519

15181520
// Convert to 16-bit int
15191521
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -2356,7 +2358,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23562358
/* Compute combined scale for the block */
23572359
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
23582360

2359-
__m256i bx = bytesFromNibbles(x[i].qs);
2361+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
23602362

23612363
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
23622364
const __m256i off = _mm256_set1_epi8( 8 );
@@ -2402,7 +2404,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
24022404
__m128i i32[2];
24032405
for (int j = 0; j < 2; ++j) {
24042406
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2405-
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2407+
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
24062408
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
24072409

24082410
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2567,7 +2569,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25672569
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
25682570

25692571
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2570-
const __m256i bx = bytesFromNibbles( x[i].qs );
2572+
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
25712573
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
25722574

25732575
// Get absolute values of x vectors
@@ -2721,6 +2723,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
27212723
}
27222724

27232725
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2726+
#elif defined(__AVX2__)
2727+
// Initialize accumulator with zeros
2728+
__m256 acc = _mm256_setzero_ps();
2729+
2730+
// Main loop
2731+
for (int i = 0; i < nb; i++) {
2732+
/* Compute combined scale for the block */
2733+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2734+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2735+
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2736+
2737+
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2738+
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2739+
__m256i bx = _mm256_set_m128i(bx1, bx0);
2740+
2741+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2742+
const __m256i off = _mm256_set1_epi8(8);
2743+
bx = _mm256_sub_epi8(bx, off);
2744+
2745+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2746+
2747+
// Get absolute values of x vectors
2748+
const __m256i ax = _mm256_sign_epi8(bx, bx);
2749+
// Sign the values of the y vectors
2750+
const __m256i sy = _mm256_sign_epi8(by, bx);
2751+
// Perform multiplication and create 16-bit values
2752+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2753+
2754+
const __m256i ones = _mm256_set1_epi16(1);
2755+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
2756+
2757+
/* Convert the vector of 8 int32_t to 8 floats */
2758+
__m256 q = _mm256_cvtepi32_ps(xy_q);
2759+
2760+
/* Multiply q with scale and accumulate */
2761+
acc = _mm256_fmadd_ps(d, q, acc);
2762+
}
2763+
2764+
// Return horizontal sum of the acc vector
2765+
__m128 res = _mm256_extractf128_ps(acc, 1);
2766+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2767+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2768+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
2769+
2770+
sumf = _mm_cvtss_f32(res);
27242771
#else
27252772
// scalar
27262773
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)