
Commit be2301b

k_quants : add AVX support to dot functions with QK_K as 64 (#2339)

* add AVX to ggml_vec_dot_q2_K_q8_K()
* add AVX to ggml_vec_dot_q3_K_q8_K()
* add AVX to ggml_vec_dot_q4_K_q8_K()
* add AVX to ggml_vec_dot_q5_K_q8_K()
* add AVX to ggml_vec_dot_q6_K_q8_K()
* refactor AVX code in ggml_vec_dot_q6_K_q8_K()

1 parent 1aa18ef commit be2301b
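
These AVX paths slot into the per-ISA dispatch inside each dot function. A minimal sketch of that structure, for orientation only (the non-AVX branch names are assumptions based on the surrounding file, not part of this diff):

#ifdef __ARM_NEON
    // NEON path (assumed from the surrounding file)
#elif defined __AVX2__
    // AVX2 path (assumed from the surrounding file)
#elif defined __AVX__
    // AVX path added by this commit, for the QK_K == 64 block layout
#else
    // scalar reference fallback
#endif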

File tree

1 file changed, 325 insertions(+), 0 deletions(-)

k_quants.c

Lines changed: 325 additions & 0 deletions
@@ -1666,6 +1666,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc) + summs;

+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2);
+        const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+        const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+        const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+        const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m256i p_0 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+        const __m256i p_1 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+        const __m256i p_2 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+        const __m256i p_3 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
 #else

     float sumf = 0;
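
For reference, a scalar sketch of what this AVX path computes for one q2_K block when QK_K == 64: four sub-blocks of 16 values, each with a 4-bit scale (low nibble) and a 4-bit min (high nibble) packed into the first four scale bytes, and the min contribution folded in via y's precomputed per-16 bsums. Function and parameter names here are illustrative, not from the file:

#include <stdint.h>

// d = y.d * ggml_fp16_to_fp32(x.d); dmin = y.d * ggml_fp16_to_fp32(x.dmin)
static float q2k_block_dot_sketch(float d, float dmin,
                                  const uint8_t scales[4], const uint8_t qs[16],
                                  const int8_t q8[64], const int16_t bsums[4]) {
    float sum = 0.0f;
    for (int j = 0; j < 4; ++j) {                // 4 sub-blocks of 16 values
        const int sc = scales[j] & 0xF;          // low nibble: sub-block scale (db[] above)
        const int mn = scales[j] >>  4;          // high nibble: sub-block min (mb[] above)
        int32_t si = 0;
        for (int l = 0; l < 16; ++l)
            si += ((qs[l] >> (2*j)) & 3) * q8[16*j + l];   // 2-bit quant times q8
        sum += d * sc * si;
        sum -= dmin * mn * bsums[j];             // min correction via the q8 sub-block sums
    }
    return sum;
}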
@@ -2295,6 +2351,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+        const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+        const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+        const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+        __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+        __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+        q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+        q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+        q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+        q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3);
+
+        // prepare low and high bits
+        const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+        const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+        const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+        const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16,
+        // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+        // and 2 if the high bit was set)
+        const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        // multiply with scales
+        p16_0 = _mm_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm_madd_epi16(scale_1, p16_1);
+        p16_2 = _mm_madd_epi16(scale_2, p16_2);
+        p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+        p16_0 = _mm_add_epi32(p16_0, p16_2);
+        p16_1 = _mm_add_epi32(p16_1, p16_3);
+        __m256i p16 = _mm256_set_m128i(p16_1, p16_0);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
 #else

     int8_t aux8[QK_K];
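
The andnot/shift dance above implements the q3_K value reconstruction: each value is its 2 low bits minus 4 whenever the hmask bit is clear, and each sub-block scale is a nibble biased by 8. A hedged scalar sketch (names illustrative), with the hmask bit position following the _mm_set_epi64x(aux64 >> 1, aux64) pairing:

#include <stdint.h>

// value index n = 16*j + l, for sub-block j in 0..3 and element l in 0..15
static int q3k_value_sketch(const uint8_t qs[16], const uint8_t hmask[8], int j, int l) {
    const int low2 = (qs[l] >> (2*j)) & 3;
    // bytes 0..7 of the 128-bit mask carry bit 2j; bytes 8..15 reuse the same
    // 8 hmask bytes one bit higher -- hence the (aux64 >> 1, aux64) pairing above
    const int hbit = (hmask[l & 7] >> (2*j + (l >> 3))) & 1;
    return low2 - 4*(1 - hbit);                  // q3h holds the 4 that gets subtracted
}

// d = y.d * ggml_fp16_to_fp32(x.d); scales are two bytes of packed nibbles
static float q3k_block_dot_sketch(float d, const uint8_t scales[2],
                                  const uint8_t qs[16], const uint8_t hmask[8],
                                  const int8_t q8[64]) {
    const int sc[4] = { (scales[0] & 0xF) - 8, (scales[0] >> 4) - 8,
                        (scales[1] & 0xF) - 8, (scales[1] >> 4) - 8 };  // scale_0..scale_3 order
    float sum = 0.0f;
    for (int j = 0; j < 4; ++j) {
        int32_t si = 0;
        for (int l = 0; l < 16; ++l)
            si += q3k_value_sketch(qs, hmask, j, l) * q8[16*j + l];
        sum += d * sc[j] * si;
    }
    return sum;
}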
@@ -2781,6 +2924,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc) - summs;

+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+        const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+        const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+        const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+        const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+        const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+        const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_1, p32_0))), acc);
+
+        const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+        const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_3, p32_2))), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
 #else

     uint8_t aux8[QK_K];
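
With QK_K == 64, q4_K has just two sub-blocks of 32 values: x.d holds the fp16 super-block scale and min as a two-element array, and the two scale/min nibble pairs come from the first two scale bytes. A hedged scalar equivalent (names illustrative; d and m arrive pre-multiplied by y.d, as above):

#include <stdint.h>

static float q4k_block_dot_sketch(float d, float m, const uint8_t scales_raw[2],
                                  const uint8_t qs[32], const int8_t q8[64],
                                  const int16_t bsums[4]) {
    const int sc[2] = { scales_raw[0] & 0xF, scales_raw[1] & 0xF };  // scales: low nibbles
    const int mn[2] = { scales_raw[0] >> 4,  scales_raw[1] >> 4  };  // mins:   high nibbles
    int32_t si0 = 0, si1 = 0;
    for (int l = 0; l < 32; ++l) {
        si0 += (qs[l] & 0xF) * q8[l];            // sub-block 0: low nibbles
        si1 += (qs[l] >>  4) * q8[32 + l];       // sub-block 1: high nibbles
    }
    const float summs = m * (mn[0] * (bsums[0] + bsums[1])
                           + mn[1] * (bsums[2] + bsums[3]));
    return d * (sc[0] * si0 + sc[1] * si1) - summs;   // matches hsum_float_8(acc) - summs
}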
@@ -3295,6 +3492,63 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mone = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5);
+
+        const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+        const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+        const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+        const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+        const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+        const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+        const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+        const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+        const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+        const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+        const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+        const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+        const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+        const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+        const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_set_m128i(dot_1, dot_0))), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
 #else

     int8_t aux8[QK_K];
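
The q5_K path reuses the q3_K high-bit trick one notch up: q5h is 16 wherever the qh bit is clear, so subtracting the s16 products from the p16 products yields (low4 + 16*hbit - 16) per value, weighted by the signed per-sub-block scales. A hedged scalar sketch (names illustrative):

#include <stdint.h>

static int q5k_value_sketch(const uint8_t qs[32], const uint8_t qh[8], int j, int l) {
    // sub-blocks 0,1 read low nibbles of qs[0..15]/qs[16..31]; 2,3 read the high nibbles
    const int low4 = (j < 2) ? (qs[16*(j & 1) + l] & 0xF) : (qs[16*(j & 1) + l] >> 4);
    // same (aux64 >> 1, aux64) byte pairing as the q3_K path, two bits per sub-block
    const int hbit = (qh[l & 7] >> (2*j + (l >> 3))) & 1;
    return low4 - 16*(1 - hbit);                 // q5h holds the 16 that gets subtracted
}

// d = y.d * ggml_fp16_to_fp32(x.d); scales are signed bytes, one per 16 values
static float q5k_block_dot_sketch(float d, const int8_t scales[4],
                                  const uint8_t qs[32], const uint8_t qh[8],
                                  const int8_t q8[64]) {
    int32_t sum = 0;
    for (int j = 0; j < 4; ++j) {
        int32_t si = 0;
        for (int l = 0; l < 16; ++l)
            si += q5k_value_sketch(qs, qh, j, l) * q8[16*j + l];
        sum += scales[j] * si;
    }
    return d * sum;
}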
@@ -3857,6 +4111,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh);
+
+        const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+        const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+        const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+        const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+        const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32));
+
+        __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+        p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+        p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
 #else

     int8_t aux8[QK_K];
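
Finally, q6_K reconstructs each value as the 4 low bits from ql merged with 2 high bits from qh, minus 32; since _mm_maddubs_epi16 requires an unsigned first operand, the code keeps the merged 6-bit value unsigned and subtracts the 32*q8 products (q8s_*) afterwards. A hedged scalar sketch (names illustrative):

#include <stdint.h>

static int q6k_value_sketch(const uint8_t ql[32], const uint8_t qh[16], int j, int l) {
    // sub-blocks 0,1 read low nibbles of ql[0..15]/ql[16..31]; 2,3 read the high nibbles
    const int low4 = (j < 2) ? (ql[16*j + l] & 0xF) : (ql[16*(j - 2) + l] >> 4);
    const int hi2  = (qh[l] >> (2*j)) & 3;       // two qh bits per sub-block
    return (low4 | (hi2 << 4)) - 32;             // 6-bit value in [-32, 31]
}

// d = y.d * ggml_fp16_to_fp32(x.d); scales are signed bytes, one per 16 values
static float q6k_block_dot_sketch(float d, const int8_t scales[4],
                                  const uint8_t ql[32], const uint8_t qh[16],
                                  const int8_t q8[64]) {
    int32_t sum = 0;
    for (int j = 0; j < 4; ++j) {
        int32_t si = 0;
        for (int l = 0; l < 16; ++l)
            si += q6k_value_sketch(ql, qh, j, l) * q8[16*j + l];
        sum += scales[j] * si;
    }
    return d * sum;
}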
