Skip to content

Commit aa3a518

Browse files
committed
Performance improvement of AVX2 code
1 parent 02c5b27 commit aa3a518

File tree

1 file changed

+80
-30
lines changed

1 file changed

+80
-30
lines changed

ggml.c

Lines changed: 80 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,47 +1958,97 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19581958
// Horizontal sum of all lanes of the accumulator
19591959
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
19601960
#elif defined(__AVX2__)
1961+
// Input: 32 Nibbles (16 bytes) at *p0
1962+
// Output: 2 vectors with 16 values of type int16_t
1963+
#define EXPAND_32_Q4_NIBBLES_INTO_TWO_M256_VECTORS(OUT_HIGH,OUT_LOW,IN_SRC) \
1964+
/* get first input */ \
1965+
/* Load 16 bytes from memory */ \
1966+
const __m128i tmp_##OUT_HIGH = \
1967+
_mm_loadu_si128( (const __m128i_u *) IN_SRC); \
1968+
\
1969+
/* Expand bytes into uint16_t values */ \
1970+
const __m256i bytes_##OUT_HIGH = _mm256_cvtepu8_epi16(tmp_##OUT_HIGH); \
1971+
\
1972+
/* Unpack values into individual bytes */ \
1973+
const __m256i pre_shift_##OUT_HIGH = \
1974+
_mm256_andnot_si256( lowMask, bytes_##OUT_HIGH ); \
1975+
__m256i OUT_HIGH = _mm256_srli_epi16( pre_shift_##OUT_HIGH, 4 ); \
1976+
\
1977+
__m256i OUT_LOW = _mm256_and_si256( lowMask, bytes_##OUT_HIGH ); \
1978+
/* Now we have a vector with bytes in [ 0 .. 15 ] interval.
1979+
Offset them into [ -8 .. +7 ] interval. */ \
1980+
OUT_HIGH = _mm256_sub_epi16( OUT_HIGH, offset_8 ); \
1981+
OUT_LOW = _mm256_sub_epi16( OUT_LOW, offset_8 );
1982+
1983+
1984+
// Input: 32 Nibbles (16 bytes) at *p0
1985+
// Output: 2 vectors with 16 values of type int16_t
1986+
#define GET_SCALE_AND_QUANT_DOT_PRODUCT(SCALE, DOT, INDEX, OFFSET, ACC)\
1987+
/* Compute combined scale for the block */ \
1988+
const __m256 SCALE = _mm256_mul_ps( \
1989+
_mm256_broadcast_ss( &x[INDEX+OFFSET].d ), \
1990+
_mm256_broadcast_ss( &y[INDEX+OFFSET].d ) ); \
1991+
\
1992+
/* Compute the dot product of the quads*/ \
1993+
/* Input: 32 Nibbles (16 bytes) at *p0
1994+
Output: 2 vectors with 16 values of type int16_t */ \
1995+
EXPAND_32_Q4_NIBBLES_INTO_TWO_M256_VECTORS( \
1996+
x_high_##DOT, \
1997+
x_low_##DOT, \
1998+
x[INDEX+OFFSET].qs) \
1999+
\
2000+
/* Input: 32 Nibbles (16 bytes) at *p1
2001+
Output: 2 vectors with 16 values of type int16_t */ \
2002+
EXPAND_32_Q4_NIBBLES_INTO_TWO_M256_VECTORS( \
2003+
y_high_##DOT, \
2004+
y_low_##DOT, \
2005+
y[INDEX+OFFSET].qs) \
2006+
\
2007+
/* Compute products of int16_t integers, add pairwise */ \
2008+
__m256i x_y_high_##DOT = \
2009+
_mm256_madd_epi16( x_high_##DOT, y_high_##DOT ); \
2010+
\
2011+
__m256i x_y_low_##DOT = \
2012+
_mm256_madd_epi16( x_low_##DOT, y_low_##DOT ); \
2013+
\
2014+
/* Accumulate products of int16_t integers */ \
2015+
__m256i x_y_##DOT = _mm256_add_epi32( \
2016+
x_y_high_##DOT, \
2017+
x_y_low_##DOT ); \
2018+
\
2019+
/* Convert int32_t to float*/ \
2020+
__m256 DOT = _mm256_cvtepi32_ps( x_y_##DOT ); \
2021+
ACC = _mm256_fmadd_ps( SCALE, DOT, ACC );
2022+
2023+
2024+
#define UNROLL_COUNT 8
2025+
2026+
// make sure we only unroll multiples of the block count
2027+
assert(nb % UNROLL_COUNT == 0);
2028+
19612029
// Initialize accumulator with zeros
19622030
__m256 acc = _mm256_setzero_ps();
19632031

19642032
// Main loop
1965-
for (int i = 0; i < nb; ++i) {
1966-
// Compute combined scale for the block
1967-
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2033+
for (int i = 0; i < nb; i+=UNROLL_COUNT) {
19682034

1969-
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1970-
__m256i bx = bytesFromNibbles( x[i].qs );
1971-
__m256i by = bytesFromNibbles( y[i].qs );
1972-
1973-
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1974-
const __m256i off = _mm256_set1_epi8( 8 );
1975-
bx = _mm256_sub_epi8( bx, off );
1976-
by = _mm256_sub_epi8( by, off );
1977-
1978-
// Sign-extend first 16 signed bytes into int16_t
1979-
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
1980-
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
1981-
// Compute products of int16_t integers, add pairwise
1982-
__m256i i32 = _mm256_madd_epi16( x16, y16 );
1983-
1984-
// Sign-extend last 16 signed bytes into int16_t vectors
1985-
x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
1986-
y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
1987-
// Accumulate products of int16_t integers
1988-
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
1989-
1990-
// Convert int32_t to float
1991-
__m256 p = _mm256_cvtepi32_ps( i32 );
1992-
// Apply the scale, and accumulate
1993-
acc = _mm256_fmadd_ps( d, p, acc );
1994-
}
2035+
/* Load 16 bytes, and unpack 4 bit fields into bytes */
2036+
const __m256i lowMask = _mm256_set1_epi8( 0xF );
2037+
const __m256i offset_8 = _mm256_set1_epi16( 8 );
2038+
2039+
// This loop will be unrolled by the compiler
2040+
for (int u=0;u<UNROLL_COUNT;u++) {
2041+
GET_SCALE_AND_QUANT_DOT_PRODUCT(scale, q, i, u, acc);
2042+
}
2043+
2044+
}
19952045

19962046
// Return horizontal sum of the acc vector
19972047
__m128 res = _mm256_extractf128_ps( acc, 1 );
19982048
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
19992049
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
20002050
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2001-
2051+
20022052
sumf = _mm_cvtss_f32( res );
20032053
#elif defined(__AVX__)
20042054
// Initialize accumulator with zeros

0 commit comments

Comments
 (0)