Skip to content

Commit c25caa6

Browse files
committed
AVX2 optimization for vec_dot_q4_2_q8_0
1 parent 884e7d7 commit c25caa6

File tree

1 file changed

+61
-15
lines changed

1 file changed

+61
-15
lines changed

ggml.c

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -463,12 +463,10 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
463463
// quantization
464464
//
465465

466-
// AVX routines provided by GH user Const-me
467-
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
468466
#if __AVX2__ || __AVX512F__
469467
// Unpack 32 4-bit fields into 32 bytes
470468
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
471-
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
469+
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
472470
{
473471
// Load 16 bytes from memory
474472
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -485,7 +483,7 @@ static inline __m256i bytesFromNibbles( const uint8_t* rsi )
485483
return bytes;
486484
}
487485

488-
static inline __m128i packNibbles( __m256i bytes )
486+
static inline __m128i pack_nibbles_32(__m256i bytes)
489487
{
490488
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
491489
const __m256i lowByte = _mm256_set1_epi16( 0xFF );
@@ -499,8 +497,11 @@ static inline __m128i packNibbles( __m256i bytes )
499497
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
500498
return _mm_packus_epi16( r0, r1 );
501499
}
502-
#elif __AVX__
503-
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
500+
#endif
501+
#if __AVX__ || __AVX2__ || __AVX512F__
502+
// Unpack 16 4-bit fields into 16 bytes
503+
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
504+
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
504505
{
505506
// Load 8 bytes from memory
506507
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
@@ -517,7 +518,7 @@ static inline __m128i bytesFromNibbles( const uint8_t* rsi )
517518
return bytes;
518519
}
519520

520-
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
521+
static inline __m128i pack_nibbles_16(__m128i bytes1, __m128i bytes2)
521522
{
522523
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
523524
const __m128i lowByte = _mm_set1_epi16( 0xFF );
@@ -820,7 +821,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
820821
i0 = _mm256_add_epi8( i0, off );
821822

822823
// Compress the vector into 4 bit/value, and store
823-
__m128i res = packNibbles( i0 );
824+
__m128i res = pack_nibbles_32(i0);
824825
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
825826
}
826827
#elif defined(__AVX__)
@@ -894,7 +895,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
894895
ni4 = _mm_add_epi8( ni4, off );
895896

896897
// Compress the vector into 4 bit/value, and store
897-
__m128i res = packNibbles( ni0, ni4 );
898+
__m128i res = pack_nibbles_16(ni0, ni4);
898899
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
899900
}
900901
#elif defined(__wasm_simd128__)
@@ -1055,7 +1056,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
10551056
i0 = _mm256_permutevar8x32_epi32( i0, perm );
10561057

10571058
// Compress the vector into 4 bit/value, and store
1058-
__m128i res = packNibbles( i0 );
1059+
__m128i res = pack_nibbles_32(i0);
10591060
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
10601061
}
10611062
#elif __ARM_NEON
@@ -1309,7 +1310,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k)
13091310

13101311
for (int l = 0; l < QK4_0; l += 32) {
13111312
// Load 32x4-bit integers into 32x8-bit integers
1312-
__m256i vx8 = bytesFromNibbles(pp+l/2);
1313+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
13131314

13141315
// Subtract 8 from the integers
13151316
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1427,7 +1428,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k)
14271428

14281429
for (int l = 0; l < QK4_1; l += 32) {
14291430
// Load 32x4-bit integers into 32x8-bit integers
1430-
__m256i vx8 = bytesFromNibbles(pp+l/2);
1431+
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
14311432

14321433
// Convert to 16-bit int
14331434
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -2270,7 +2271,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
22702271
/* Compute combined scale for the block */
22712272
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
22722273

2273-
__m256i bx = bytesFromNibbles(x[i].qs);
2274+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
22742275

22752276
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
22762277
const __m256i off = _mm256_set1_epi8( 8 );
@@ -2316,7 +2317,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23162317
__m128i i32[2];
23172318
for (int j = 0; j < 2; ++j) {
23182319
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2319-
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2320+
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
23202321
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
23212322

23222323
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2481,7 +2482,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
24812482
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
24822483

24832484
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2484-
const __m256i bx = bytesFromNibbles( x[i].qs );
2485+
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
24852486
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
24862487

24872488
// Get absolute values of x vectors
@@ -2635,6 +2636,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
26352636
}
26362637

26372638
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2639+
#elif defined(__AVX2__)
2640+
// Initialize accumulator with zeros
2641+
__m256 acc = _mm256_setzero_ps();
2642+
2643+
// Main loop
2644+
for (int i = 0; i < nb; i++) {
2645+
/* Compute combined scale for the block */
2646+
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
2647+
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
2648+
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
2649+
2650+
__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
2651+
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
2652+
__m256i bx = _mm256_set_m128i(bx1, bx0);
2653+
2654+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2655+
const __m256i off = _mm256_set1_epi8(8);
2656+
bx = _mm256_sub_epi8(bx, off);
2657+
2658+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2659+
2660+
// Get absolute values of x vectors
2661+
const __m256i ax = _mm256_sign_epi8(bx, bx);
2662+
// Sign the values of the y vectors
2663+
const __m256i sy = _mm256_sign_epi8(by, bx);
2664+
// Perform multiplication and create 16-bit values
2665+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2666+
2667+
const __m256i ones = _mm256_set1_epi16(1);
2668+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
2669+
2670+
/* Convert vector of 8 int32_t to 8 floats */
2671+
__m256 q = _mm256_cvtepi32_ps(xy_q);
2672+
2673+
/* Multiply q with scale and accumulate */
2674+
acc = _mm256_fmadd_ps(d, q, acc);
2675+
}
2676+
2677+
// Return horizontal sum of the acc vector
2678+
__m128 res = _mm256_extractf128_ps(acc, 1);
2679+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2680+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2681+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
2682+
2683+
sumf = _mm_cvtss_f32(res);
26382684
#else
26392685
// scalar
26402686
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)