AVX2 optimization for vec_dot_q4_2_q8_0 #1068

Merged · 1 commit · Apr 20, 2023

99 changes: 73 additions & 26 deletions ggml.c
@@ -463,12 +463,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
// quantization
//

// AVX routines provided by GH user Const-me
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
#if __AVX__ || __AVX2__ || __AVX512F__
// Unpack 16 4-bit fields into 16 bytes
// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
{
// Load 8 bytes from memory
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );

// Expand bytes into uint16_t values
__m128i bytes = _mm_cvtepu8_epi16( tmp );

// Unpack values into individual bytes
const __m128i lowMask = _mm_set1_epi8( 0xF );
__m128i high = _mm_andnot_si128( lowMask, bytes );
__m128i low = _mm_and_si128( lowMask, bytes );
high = _mm_slli_epi16( high, 4 );
bytes = _mm_or_si128( low, high );
return bytes;
}
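
For intuition, here is a scalar sketch of what this routine computes (illustrative only, not part of the patch; the helper name is made up): each source byte expands into two output bytes, low nibble first.

// Scalar equivalent of bytes_from_nibbles_16 (illustrative sketch)
static inline void bytes_from_nibbles_16_scalar(const uint8_t * rsi, uint8_t out[16])
{
    for (int k = 0; k < 8; ++k) {
        out[2*k + 0] = rsi[k] & 0x0F; // low nibble  -> even output byte
        out[2*k + 1] = rsi[k] >> 4;   // high nibble -> odd output byte
    }
}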

#if __AVX2__ || __AVX512F__
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
// Load 16 bytes from memory
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -499,24 +517,7 @@ static inline __m128i packNibbles( __m256i bytes )
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
return _mm_packus_epi16( r0, r1 );
}
#elif __AVX__
static inline __m128i bytesFromNibbles( const uint8_t* rsi )
{
// Load 8 bytes from memory
__m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );

// Expand bytes into uint16_t values
__m128i bytes = _mm_cvtepu8_epi16( tmp );

// Unpack values into individual bytes
const __m128i lowMask = _mm_set1_epi8( 0xF );
__m128i high = _mm_andnot_si128( lowMask, bytes );
__m128i low = _mm_and_si128( lowMask, bytes );
high = _mm_slli_epi16( high, 4 );
bytes = _mm_or_si128( low, high );
return bytes;
}

#else
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -533,6 +534,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
return _mm_packus_epi16( bytes1, bytes2);
}
#endif
#endif // __AVX__ || __AVX2__ || __AVX512F__
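
packNibbles is the inverse of the unpack helpers: pairs of byte values in [0..15] are fused back into single nibble-packed bytes. A scalar sketch of the intended result (the function name is illustrative):

// Scalar sketch of the packNibbles result (illustrative only)
static inline void pack_nibbles_scalar(const uint8_t in[32], uint8_t out[16])
{
    for (int k = 0; k < 16; ++k) {
        // low nibble from the even element, high nibble from the odd one
        out[k] = (uint8_t)(in[2*k + 0] | (in[2*k + 1] << 4));
    }
}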

#if __ARM_NEON

@@ -1309,7 +1311,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in

for (int l = 0; l < QK4_0; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers
__m256i vx8 = bytesFromNibbles(pp+l/2);
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);

// Subtract 8 from the integers
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1427,7 +1429,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in

for (int l = 0; l < QK4_1; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers
__m256i vx8 = bytesFromNibbles(pp+l/2);
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);

// Convert to 16-bit int
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -2270,7 +2272,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );

__m256i bx = bytesFromNibbles(x[i].qs);
__m256i bx = bytes_from_nibbles_32(x[i].qs);

// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
@@ -2316,7 +2318,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
__m128i i32[2];
for (int j = 0; j < 2; ++j) {
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
__m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));

// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2481,7 +2483,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );

// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
const __m256i bx = bytesFromNibbles( x[i].qs );
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );

// Get absolute values of x vectors
@@ -2635,6 +2637,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
}

sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

// Main loop
for (int i = 0; i < nb; i++) {
/* Compute the combined scale for the block: each 32-quant q8_0 block of y
   spans two 16-quant q4_2 blocks of x, each with its own fp16 scale */
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));

__m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
__m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
__m256i bx = _mm256_set_m128i(bx1, bx0);

// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8(8);
bx = _mm256_sub_epi8(bx, off);

__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);

// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);

// Sum adjacent pairs of 16-bit products into 32-bit integers
const __m256i ones = _mm256_set1_epi16(1);
__m256i xy_q = _mm256_madd_epi16(ones, dot);

/* Convert the vector of 8 int32_t values to 8 floats */
__m256 q = _mm256_cvtepi32_ps(xy_q);

/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps(d, q, acc);
}

// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps(acc, 1);
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
res = _mm_add_ss(res, _mm_movehdup_ps(res));

sumf = _mm_cvtss_f32(res);
#else
// scalar
for (int i = 0; i < nb; i++) {
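
The scalar fallback is truncated in the diff above. For reference, a minimal sketch of the equivalent scalar computation, assuming the block layouts from the surrounding file (QK4_2 = 16 quants with an fp16 scale, QK8_0 = 32 quants with a float scale, low nibble stored first); this is not the committed code:

// Illustrative scalar sketch of the q4_2 x q8_0 dot product
float sumf_ref = 0.0f;
for (int i = 0; i < nb; i++) {
    const uint8_t * x0 = x[2*i + 0].qs; // first q4_2 half of the 32-quant span
    const uint8_t * x1 = x[2*i + 1].qs; // second q4_2 half
    const int8_t  * y0 = y[i].qs;

    int sumi0 = 0;
    int sumi1 = 0;
    for (int j = 0; j < 8; j++) {
        // Unpack two 4-bit quants per byte and recenter [0..15] -> [-8..7]
        sumi0 += ((x0[j] & 0x0F) - 8) * y0[2*j + 0] + ((x0[j] >> 4) - 8) * y0[2*j + 1];
        sumi1 += ((x1[j] & 0x0F) - 8) * y0[16 + 2*j + 0] + ((x1[j] >> 4) - 8) * y0[16 + 2*j + 1];
    }

    // Apply the per-block scales, matching the combined-scale multiply in the AVX2 path
    sumf_ref += GGML_FP16_TO_FP32(x[2*i + 0].d) * y[i].d * sumi0;
    sumf_ref += GGML_FP16_TO_FP32(x[2*i + 1].d) * y[i].d * sumi1;
}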