@@ -474,6 +474,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 // multiply int8_t, add results pairwise twice
 static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
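The new `MM256_SET_M128I(a, b)` macro rebuilds `_mm256_set_m128i(a, b)` (upper lane = `a`, lower lane = `b`) from the plain AVX intrinsics `_mm256_castsi128_si256` and `_mm256_insertf128_si256`, presumably so the file still builds on older compilers that ship AVX support but lack `_mm256_set_m128i`. A minimal standalone check of the lane order (the test scaffold and file name are mine, not part of the commit):

```c
// sketch: verify MM256_SET_M128I(hi, lo) matches _mm256_set_m128i(hi, lo)
// semantics, i.e. `lo` fills bytes 0..15 and `hi` fills bytes 16..31.
// Build with e.g. `gcc -mavx check_m256.c` (hypothetical file name).
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

int main(void) {
    const __m128i lo = _mm_set1_epi8(1);
    const __m128i hi = _mm_set1_epi8(2);
    uint8_t out[32];
    _mm256_storeu_si256((__m256i *)out, MM256_SET_M128I(hi, lo));
    for (int i = 0; i < 32; ++i) {
        printf("%d%s", out[i], i == 31 ? "\n" : " ");  // expect 16 ones, then 16 twos
    }
    return 0;
}
```

The cast leaves the upper 128 bits undefined, but the `insertf128` at index 1 overwrites exactly those bits, so the result is fully defined with AVX-only intrinsics.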
@@ -533,7 +535,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
-    const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
     return _mm256_and_si256(lowMask, bytes);
 }
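Only the combine step changes in `bytes_from_nibbles_32`; the data movement stays the same: the 16 low nibbles land in the lower lane and the 16 high nibbles (after the 4-bit shift and mask) in the upper lane. A scalar reference, mine rather than the commit's, for spot-checking that order:

```c
// sketch: scalar equivalent of bytes_from_nibbles_32 under the lane
// order stated above (low nibbles first, then high nibbles)
#include <stdint.h>

static void bytes_from_nibbles_32_ref(const uint8_t *in16, uint8_t *out32) {
    for (int j = 0; j < 16; ++j) {
        out32[j]      = in16[j] & 0x0F;  // low nibble -> bytes 0..15
        out32[j + 16] = in16[j] >> 4;    // high nibble -> bytes 16..31
    }
}
```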
@@ -606,7 +608,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     bytesh = _mm_or_si128(bytesh, bit_mask);
     bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
     bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
-    return _mm256_set_m128i(bytesh, bytesl);
+    return MM256_SET_M128I(bytesh, bytesl);
 }
 
 // Unpack 32 4-bit fields into 32 bytes
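The same mechanical swap in the SSSE3 tail of `bytes_from_bits_32`. Judging from the or/cmpeq idiom above, the function expands each bit of a 32-bit input into a full 0x00/0xFF byte; the lane recombination is the only thing this hunk touches. An inferred scalar reference (my reading, since the top of the function sits outside the hunk):

```c
// sketch: assumed behavior of bytes_from_bits_32, inferred from the
// visible tail: bit j of the input selects 0xFF or 0x00 for byte j
#include <stdint.h>
#include <string.h>

static void bytes_from_bits_32_ref(const uint8_t *x, uint8_t *out32) {
    uint32_t bits;
    memcpy(&bits, x, sizeof(bits));  // the function takes a byte pointer
    for (int j = 0; j < 32; ++j) {
        out32[j] = ((bits >> j) & 1) ? 0xFF : 0x00;
    }
}
```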
@@ -619,15 +621,15 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     const __m128i lowMask = _mm_set1_epi8(0xF);
     tmpl = _mm_and_si128(lowMask, tmpl);
     tmph = _mm_and_si128(lowMask, tmph);
-    return _mm256_set_m128i(tmph, tmpl);
+    return MM256_SET_M128I(tmph, tmpl);
 }
 
 // add int16_t pairwise and return as float vector
 static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
     const __m128i ones = _mm_set1_epi16(1);
     const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
     const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
-    const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
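In `sum_i16_pairs_float`, `_mm_madd_epi16` with a vector of ones sums adjacent int16 pairs into int32 lanes; the macro then stitches the two halves (`xl` low, `xh` high) before the int32-to-float conversion. A scalar view (mine, with the 16 inputs laid out `xl` first):

```c
// sketch: scalar equivalent of sum_i16_pairs_float; x16[0..7] are the
// lanes of xl, x16[8..15] the lanes of xh
#include <stdint.h>

static void sum_i16_pairs_float_ref(const int16_t *x16, float *out8) {
    for (int j = 0; j < 8; ++j) {
        out8[j] = (float)((int32_t)x16[2 * j] + (int32_t)x16[2 * j + 1]);
    }
}
```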
@@ -2290,7 +2292,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
 
         // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
+        __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
 
         // Apply the scale, and accumulate
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
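Note the argument order: `MM256_SET_M128I(i32_0, i32_1)` puts `i32_1` in the lower lane and `i32_0` in the upper one, matching the `_mm256_set_m128i` call it replaces; since `acc` is presumably reduced horizontally after the loop, the lane order does not affect the final sum. The scalar shape of the step (my sketch; `d` stands for the broadcast per-block scale product from the surrounding code):

```c
// sketch: scalar shape of "convert int32 to float, scale, accumulate";
// i32 holds the eight pair-sums, d the per-block scale product
#include <stdint.h>

static void scale_accumulate_ref(float *acc8, const int32_t *i32, float d) {
    for (int j = 0; j < 8; ++j) {
        acc8[j] += d * (float)i32[j];  // acc = acc + d * p
    }
}
```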
@@ -2766,7 +2768,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
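This hunk, and the identical one in `ggml_vec_dot_q5_1_q8_1` below, only changes how the two halves of `bx` are recombined after the high bits are OR-ed in; `bxhil`/`bxhih` come from code outside the hunk, so their exact contents are assumed here. Per byte, the visible step amounts to:

```c
// sketch: byte-level effect of the two _mm_or_si128 calls plus the
// recombine; bxhi32 is whatever high-bit pattern the earlier (unshown)
// lines computed for each byte
#include <stdint.h>

static void merge_high_bits_ref(uint8_t *bx32, const uint8_t *bxhi32) {
    for (int j = 0; j < 32; ++j) {
        bx32[j] |= bxhi32[j];
    }
}
```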
@@ -3022,7 +3024,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         __m128i bxh = _mm256_extractf128_si256(bx, 1);
         bxl = _mm_or_si128(bxl, bxhil);
         bxh = _mm_or_si128(bxh, bxhih);
-        bx = _mm256_set_m128i(bxh, bxl);
+        bx = MM256_SET_M128I(bxh, bxl);
 
         const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);