Skip to content

Commit 60f8c36

Browse files
authored
ggml : add AVX support based on AVX2 code (#1430)
1 parent 601a033 commit 60f8c36

File tree

1 file changed

+132
-3
lines changed

1 file changed

+132
-3
lines changed

ggml.c

Lines changed: 132 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,63 @@ static inline __m128i packNibbles( __m256i bytes )
580580
return _mm_packus_epi16( r0, r1 );
581581
#endif
582582
}
583-
#else
583+
#elif defined(__AVX__)
584+
// spread 32 bits to 32 bytes { 0x00, 0xFF }
585+
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
586+
uint32_t x32;
587+
memcpy(&x32, x, sizeof(uint32_t));
588+
const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
589+
const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
590+
__m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
591+
__m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
592+
const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
593+
bytesl = _mm_or_si128(bytesl, bit_mask);
594+
bytesh = _mm_or_si128(bytesh, bit_mask);
595+
bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
596+
bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
597+
return _mm256_set_m128i(bytesh, bytesl);
598+
}
599+
600+
// Unpack 32 4-bit fields into 32 bytes
601+
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
602+
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
603+
{
604+
// Load 16 bytes from memory
605+
__m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
606+
__m128i tmph = _mm_srli_epi16(tmpl, 4);
607+
const __m128i lowMask = _mm_set1_epi8(0xF);
608+
tmpl = _mm_and_si128(lowMask, tmpl);
609+
tmph = _mm_and_si128(lowMask, tmph);
610+
return _mm256_set_m128i(tmph, tmpl);
611+
}
612+
613+
// add int16_t pairwise and return as float vector
614+
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
615+
const __m128i ones = _mm_set1_epi16(1);
616+
const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
617+
const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
618+
const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
619+
return _mm256_cvtepi32_ps(summed_pairs);
620+
}
621+
622+
// multiply int8_t, add results pairwise twice and return as float vector
623+
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
624+
const __m128i xl = _mm256_castsi256_si128(x);
625+
const __m128i xh = _mm256_extractf128_si256(x, 1);
626+
const __m128i yl = _mm256_castsi256_si128(y);
627+
const __m128i yh = _mm256_extractf128_si256(y, 1);
628+
// Get absolute values of x vectors
629+
const __m128i axl = _mm_sign_epi8(xl, xl);
630+
const __m128i axh = _mm_sign_epi8(xh, xh);
631+
// Sign the values of the y vectors
632+
const __m128i syl = _mm_sign_epi8(yl, xl);
633+
const __m128i syh = _mm_sign_epi8(yh, xh);
634+
// Perform multiplication and create 16-bit values
635+
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
636+
const __m128i doth = _mm_maddubs_epi16(axh, syh);
637+
return sum_i16_pairs_float(doth, dotl);
638+
}
639+
584640
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
585641
{
586642
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -2355,7 +2411,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
23552411
}
23562412

23572413
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
2358-
#elif defined(__AVX2__)
2414+
#elif defined(__AVX2__) || defined(__AVX__)
23592415
// Initialize accumulator with zeros
23602416
__m256 acc = _mm256_setzero_ps();
23612417

@@ -2381,7 +2437,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
23812437
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
23822438

23832439
// Accumulate d0*d1*x*y
2440+
#if defined(__AVX2__)
23842441
acc = _mm256_fmadd_ps( d0d1, xy, acc );
2442+
#else
2443+
acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
2444+
#endif
23852445
}
23862446

23872447
*s = hsum_float_8(acc) + summs;
@@ -2592,6 +2652,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
25922652
acc = _mm256_fmadd_ps(d, q, acc);
25932653
}
25942654

2655+
*s = hsum_float_8(acc);
2656+
#elif defined(__AVX__)
2657+
// Initialize accumulator with zeros
2658+
__m256 acc = _mm256_setzero_ps();
2659+
__m128i mask = _mm_set1_epi8((char)0xF0);
2660+
2661+
// Main loop
2662+
for (int i = 0; i < nb; i++) {
2663+
/* Compute combined scale for the block */
2664+
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
2665+
2666+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
2667+
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
2668+
__m128i bxhil = _mm256_castsi256_si128(bxhi);
2669+
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
2670+
bxhil = _mm_andnot_si128(bxhil, mask);
2671+
bxhih = _mm_andnot_si128(bxhih, mask);
2672+
__m128i bxl = _mm256_castsi256_si128(bx);
2673+
__m128i bxh = _mm256_extractf128_si256(bx, 1);
2674+
bxl = _mm_or_si128(bxl, bxhil);
2675+
bxh = _mm_or_si128(bxh, bxhih);
2676+
bx = _mm256_set_m128i(bxh, bxl);
2677+
2678+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2679+
2680+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
2681+
2682+
/* Multiply q with scale and accumulate */
2683+
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
2684+
}
2685+
25952686
*s = hsum_float_8(acc);
25962687
#else
25972688
// scalar
@@ -2820,6 +2911,40 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
28202911
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
28212912
}
28222913

2914+
*s = hsum_float_8(acc) + summs;
2915+
#elif defined(__AVX__)
2916+
// Initialize accumulator with zeros
2917+
__m256 acc = _mm256_setzero_ps();
2918+
__m128i mask = _mm_set1_epi8(0x10);
2919+
2920+
float summs = 0.0f;
2921+
2922+
// Main loop
2923+
for (int i = 0; i < nb; i++) {
2924+
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
2925+
2926+
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
2927+
2928+
__m256i bx = bytes_from_nibbles_32(x[i].qs);
2929+
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
2930+
__m128i bxhil = _mm256_castsi256_si128(bxhi);
2931+
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
2932+
bxhil = _mm_and_si128(bxhil, mask);
2933+
bxhih = _mm_and_si128(bxhih, mask);
2934+
__m128i bxl = _mm256_castsi256_si128(bx);
2935+
__m128i bxh = _mm256_extractf128_si256(bx, 1);
2936+
bxl = _mm_or_si128(bxl, bxhil);
2937+
bxh = _mm_or_si128(bxh, bxhih);
2938+
bx = _mm256_set_m128i(bxh, bxl);
2939+
2940+
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
2941+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2942+
2943+
const __m256 q = mul_sum_i8_pairs_float(bx, by);
2944+
2945+
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
2946+
}
2947+
28232948
*s = hsum_float_8(acc) + summs;
28242949
#else
28252950
// scalar
@@ -2910,7 +3035,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
29103035
}
29113036

29123037
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2913-
#elif defined(__AVX2__)
3038+
#elif defined(__AVX2__) || defined(__AVX__)
29143039
// Initialize accumulator with zeros
29153040
__m256 acc = _mm256_setzero_ps();
29163041

@@ -2924,7 +3049,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
29243049
const __m256 q = mul_sum_i8_pairs_float(bx, by);
29253050

29263051
// Multiply q with scale and accumulate
3052+
#if defined(__AVX2__)
29273053
acc = _mm256_fmadd_ps( d, q, acc );
3054+
#else
3055+
acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
3056+
#endif
29283057
}
29293058

29303059
*s = hsum_float_8(acc);

0 commit comments

Comments
 (0)