Skip to content

Commit 7202864

Browse files
committed
AVX2 optimization for vec_dot_q4_2_q8_0
1 parent 884e7d7 commit 7202864

File tree

1 file changed

+73
-26
lines changed

1 file changed

+73
-26
lines changed

ggml.c

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -463,12 +463,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
463463
// quantization
464464
//
465465

466-
// AVX routines provided by GH user Const-me
467-
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
466+
#if __AVX__ || __AVX2__ || __AVX512F__
467+
// Unpack 16 4-bit fields into 16 bytes
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
{
    // Load 8 packed bytes (16 nibbles) from memory
    __m128i packed = _mm_loadu_si64( ( const __m128i* )rsi );

    // Zero-extend each byte into its own 16-bit lane
    __m128i lanes = _mm_cvtepu8_epi16( packed );

    // Separate the low and high nibble of every source byte
    const __m128i nibble_mask = _mm_set1_epi8( 0xF );
    __m128i lo = _mm_and_si128( nibble_mask, lanes );
    __m128i hi = _mm_andnot_si128( nibble_mask, lanes );

    // Move the high nibble into the upper byte of its lane, then merge:
    // each output byte now holds a single value in [ 0 .. 15 ]
    hi = _mm_slli_epi16( hi, 4 );
    return _mm_or_si128( lo, hi );
}
485+
468486
#if __AVX2__ || __AVX512F__
469487
// Unpack 32 4-bit fields into 32 bytes
470488
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
471-
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
489+
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
472490
{
473491
// Load 16 bytes from memory
474492
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -499,24 +517,7 @@ static inline __m128i packNibbles( __m256i bytes )
499517
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
500518
return _mm_packus_epi16( r0, r1 );
501519
}
502-
#elif __AVX__
503-
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
504-
{
505-
// Load 8 bytes from memory
506-
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
507-
508-
// Expand bytes into uint16_t values
509-
__m128i bytes = _mm_cvtepu8_epi16( tmp );
510-
511-
// Unpack values into individual bytes
512-
const __m128i lowMask = _mm_set1_epi8( 0xF );
513-
__m128i high = _mm_andnot_si128( lowMask, bytes );
514-
__m128i low = _mm_and_si128( lowMask, bytes );
515-
high = _mm_slli_epi16( high, 4 );
516-
bytes = _mm_or_si128( low, high );
517-
return bytes;
518-
}
519-
520+
#else
520521
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
521522
{
522523
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -533,6 +534,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
533534
return _mm_packus_epi16( bytes1, bytes2);
534535
}
535536
#endif
537+
#endif // __AVX__ || __AVX2__ || __AVX512F__
536538

537539
#if __ARM_NEON
538540

@@ -1309,7 +1311,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
13091311

13101312
for (int l = 0; l < QK4_0; l += 32) {
13111313
// Load 32x4-bit integers into 32x8-bit integers
1312-
__m256i vx8 = bytesFromNibbles(pp+l/2);
1314+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
13131315

13141316
// Subtract 8 from the integers
13151317
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1427,7 +1429,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
14271429

14281430
for (int l = 0; l < QK4_1; l += 32) {
14291431
// Load 32x4-bit integers into 32x8-bit integers
1430-
__m256i vx8 = bytesFromNibbles(pp+l/2);
1432+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
14311433

14321434
// Convert to 16-bit int
14331435
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -2270,7 +2272,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
22702272
/* Compute combined scale for the block */
22712273
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
22722274

2273-
__m256i bx = bytesFromNibbles(x[i].qs);
2275+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
22742276

22752277
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
22762278
const __m256i off = _mm256_set1_epi8( 8 );
@@ -2316,7 +2318,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23162318
__m128i i32[2];
23172319
for (int j = 0; j < 2; ++j) {
23182320
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2319-
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2321+
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
23202322
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
23212323

23222324
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2481,7 +2483,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
24812483
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
24822484

24832485
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2484-
const __m256i bx = bytesFromNibbles( x[i].qs );
2486+
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
24852487
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
24862488

24872489
// Get absolute values of x vectors
@@ -2635,6 +2637,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
26352637
}
26362638

26372639
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2640+
#elif defined(__AVX2__)
2641+
// Initialize accumulator with zeros
2642+
__m256 acc = _mm256_setzero_ps();
2643+
2644+
// Main loop
2645+
for (int i = 0; i < nb; i++) {
2646+
/* Compute combined scale for the block */
2647+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2648+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2649+
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2650+
2651+
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2652+
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2653+
__m256i bx = _mm256_set_m128i(bx1, bx0);
2654+
2655+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2656+
const __m256i off = _mm256_set1_epi8(8);
2657+
bx = _mm256_sub_epi8(bx, off);
2658+
2659+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2660+
2661+
// Get absolute values of x vectors
2662+
const __m256i ax = _mm256_sign_epi8(bx, bx);
2663+
// Sign the values of the y vectors
2664+
const __m256i sy = _mm256_sign_epi8(by, bx);
2665+
// Perform multiplication and create 16-bit values
2666+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2667+
2668+
const __m256i ones = _mm256_set1_epi16(1);
2669+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
2670+
2671+
/* Convert vector of 8 int32_t to 8 floats */
2672+
__m256 q = _mm256_cvtepi32_ps(xy_q);
2673+
2674+
/* Multiply q with scale and accumulate */
2675+
acc = _mm256_fmadd_ps(d, q, acc);
2676+
}
2677+
2678+
// Return horizontal sum of the acc vector
2679+
__m128 res = _mm256_extractf128_ps(acc, 1);
2680+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2681+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2682+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
2683+
2684+
sumf = _mm_cvtss_f32(res);
26382685
#else
26392686
// scalar
26402687
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)