@@ -1666,6 +1666,62 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc) + summs;

+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t ud, um;
+    const uint8_t * restrict db = (const uint8_t *)&ud;
+    const uint8_t * restrict mb = (const uint8_t *)&um;
+
+    float summs = 0;
+
+    // TODO: optimize this
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+        ud = (sc[0] >> 0) & 0x0f0f0f0f;
+        um = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3];
+        summs += dmin * smin;
+
+        const __m128i q2bits = _mm_loadu_si128((const __m128i *)q2);
+        const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+        const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+        const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+        const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)(q8 + 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)(q8 + 32));
+
+        const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m256i p_0 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0));
+        const __m256i p_1 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1));
+        const __m256i p_2 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2));
+        const __m256i p_3 = _mm256_set_m128i(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
 #else

     float sumf = 0;
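
For reference, this is what the new `__AVX__` branch computes per super-block, written out in scalar form. This is only an illustrative sketch, assuming the 64-weight block layout implied by the 16-byte `qs` load; the helper name and fixed array sizes are not part of the patch.

#include <stdint.h>

// Scalar equivalent of one q2_K x q8_K super-block dot product:
// d    = y->d * fp16_to_fp32(x->d), dmin = y->d * fp16_to_fp32(x->dmin),
// bsums[k] = sum of the k-th group of 16 q8 values (as stored in block_q8_K).
static float q2_block_dot_scalar(const uint8_t qs[16], const uint8_t scales[4],
                                 const int8_t q8[64], const int16_t bsums[4],
                                 float d, float dmin) {
    float sum = 0.0f;
    int32_t smin = 0;
    for (int k = 0; k < 4; ++k) {
        const int sc = scales[k] & 0xF;   // low nibbles  -> ud/db[] in the AVX code
        const int mn = scales[k] >> 4;    // high nibbles -> um/mb[] in the AVX code
        int32_t dot = 0;
        for (int j = 0; j < 16; ++j) {
            dot += ((qs[j] >> (2*k)) & 3) * q8[16*k + j];   // the q2_k bit plane times the matching q8 group
        }
        sum  += d * sc * dot;             // the four acc updates
        smin += mn * bsums[k];            // the mb[]*bsums[] line
    }
    return sum - dmin * smin;             // the AVX code folds the minus sign into its dmin and adds summs
}
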
@@ -2295,6 +2351,93 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint64_t aux64;
+
+    uint16_t aux16[2];
+    const int8_t * aux8 = (const int8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8);
+        const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8);
+        const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8);
+        const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8);
+
+        memcpy(&aux64, x[i].hmask, 8);
+
+        __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0);
+        __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2);
+        __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4);
+        __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6);
+        q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2);
+        q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2);
+        q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2);
+        q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2);
+
+        // load low 2 bits
+        const __m128i q3bits = _mm_loadu_si128((const __m128i *)q3);
+
+        // prepare low and high bits
+        const __m128i q3l_0 = _mm_and_si128(q3bits, m3);
+        const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3);
+        const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3);
+        const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3);
+
+        // load Q8 quants
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)(q8 + 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)(q8 + 32));
+
+        // Dot product: we multiply the 2 low bits and the 1 high bit part separately, so we can use _mm_maddubs_epi16,
+        // and then subtract. The high bit part already carries the 4 to be subtracted (it is 4 if the high bit was not set,
+        // and zero if the high bit was set)
+        const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        // multiply with scales
+        p16_0 = _mm_madd_epi16(scale_0, p16_0);
+        p16_1 = _mm_madd_epi16(scale_1, p16_1);
+        p16_2 = _mm_madd_epi16(scale_2, p16_2);
+        p16_3 = _mm_madd_epi16(scale_3, p16_3);
+
+        p16_0 = _mm_add_epi32(p16_0, p16_2);
+        p16_1 = _mm_add_epi32(p16_1, p16_3);
+        __m256i p16 = _mm256_set_m128i(p16_1, p16_0);
+
+        // multiply with block scale and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
 #else

     int8_t aux8[QK_K];
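
_mm_maddubs_epi16 needs an unsigned first operand, which is why the signed 3-bit value is split into two unsigned products that are subtracted afterwards. A per-lane scalar sketch of that trick follows; the actual lane-to-hmask-bit mapping comes from the _mm_set_epi64x/_mm_srli_epi16 shuffles above, and the helper itself is illustrative only.

#include <stdint.h>

static inline int32_t q3_lane_dot(uint8_t lo2 /* 0..3 */, int hbit_set, int8_t q8) {
    const int32_t p16 = lo2 * q8;                  // _mm_maddubs_epi16(q3l_x, q8)
    const int32_t q8s = (hbit_set ? 0 : 4) * q8;   // _mm_maddubs_epi16(q3h_x, q8): q3h_x is 4 iff the mask bit is clear
    return p16 - q8s;                              // == (lo2 - (hbit_set ? 0 : 4)) * q8, a value in [-4, 3] times q8
}

The lane results are then weighted by the signed sub-block scales (the `aux8[k] - 8` nibbles) and by the block scale `d`.
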
@@ -2781,6 +2924,60 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc) - summs;

+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    uint16_t aux16[2];
+    const uint8_t * scales = (const uint8_t *)aux16;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d;
+        const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d;
+        const __m256 vd = _mm256_set1_ps(d);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+        summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i q4bits = _mm256_loadu_si256((const __m256i *)q4);
+        const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0);
+        const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1);
+        const __m128i q4_0 = _mm_and_si128(q4bits_0, m4);
+        const __m128i q4_1 = _mm_and_si128(q4bits_1, m4);
+        const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4);
+        const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)(q8 + 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)(q8 + 32));
+
+        const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0);
+        const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_1, p32_0))), acc);
+
+        const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2);
+        const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_set_m128i(p32_3, p32_2))), acc);
+
+    }
+
+    *s = hsum_float_8(acc) - summs;
+
 #else

     uint8_t aux8[QK_K];
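
The two packed scale bytes are split by the aux16 masking so that the low nibbles act as scales for the two 32-value halves and the high nibbles act as mins that only touch the q8 block sums. A scalar sketch of that unpacking, with a hypothetical helper that is not part of the patch:

#include <stdint.h>

static void q4_unpack_scales(const uint8_t packed[2],
                             uint8_t sc[2],   // scales for the low-nibble and high-nibble halves
                             uint8_t mn[2]) { // mins for the same two halves
    sc[0] = packed[0] & 0xF;   // -> scales[0], applied to p16_0/p16_1 (q8[0..31])
    sc[1] = packed[1] & 0xF;   // -> scales[1], applied to p16_2/p16_3 (q8[32..63])
    mn[0] = packed[0] >> 4;    // -> scales[2], paired with bsums[0] + bsums[1]
    mn[1] = packed[1] >> 4;    // -> scales[3], paired with bsums[2] + bsums[3]
}

The block result is then d*(sc[0]*dot_low + sc[1]*dot_high) - m*(mn[0]*bsum_low + mn[1]*bsum_high), which is why summs is subtracted at the end.
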
@@ -3295,6 +3492,63 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mone = _mm_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const __m256i q5bits = _mm256_loadu_si256((const __m256i *)q5);
+
+        const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]);
+        const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]);
+        const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]);
+        const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]);
+
+        int64_t aux64;
+        memcpy(&aux64, x[i].qh, 8);
+        const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64);
+        const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2);
+
+        const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4);
+        const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4);
+        const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4);
+        const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4);
+
+        const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4);
+        const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4);
+        const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4);
+        const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)(q8 + 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)(q8 + 32));
+
+        const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1)));
+        const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0)));
+        const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1)));
+        const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0)));
+        const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1)));
+
+        const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2));
+        const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_set_m128i(dot_1, dot_0))), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
 #else

     int8_t aux8[QK_K];
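
Here the qh bit toggles a 16 offset rather than being OR-ed in directly: q5h_x is 16 exactly when the corresponding qh bit is clear, and the s16 products are subtracted from the p16 products. A per-lane scalar sketch (illustrative only; the lane-to-qh-bit mapping follows the _mm_set_epi64x/_mm_srli_epi16 shuffles above):

#include <stdint.h>

static inline int32_t q5_lane_dot(uint8_t lo4 /* 0..15 */, int hbit_set,
                                  int8_t scale, int8_t q8) {
    const int32_t p = lo4 * q8;                   // _mm_maddubs_epi16(q5l_x, q8)
    const int32_t s = (hbit_set ? 0 : 16) * q8;   // _mm_maddubs_epi16(q5h_x, q8): q5h_x is 16 iff the qh bit is clear
    return scale * (p - s);                       // a signed value in [-16, 15] times q8, times the sub-block scale
}
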
@@ -3857,6 +4111,77 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(3);
+    const __m128i m32s = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]);
+        const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]);
+        const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]);
+        const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1);
+        const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3);
+
+        const __m256i q4bits1 = _mm256_loadu_si256((const __m256i *)q4);
+        const __m128i q4bitsH = _mm_loadu_si128((const __m128i *)qh);
+
+        const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4);
+        const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4);
+        const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4);
+        const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4);
+
+        const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0);
+        const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1);
+        const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2);
+        const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3);
+
+        const __m256i q8_0 = _mm256_loadu_si256((const __m256i *)(q8 + 0));
+        const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)(q8 + 32));
+
+        __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0));
+        __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1));
+        __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0));
+        __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1));
+
+        __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0));
+        __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1));
+        __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0));
+        __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1));
+
+        p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+        p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+        p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+        p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+
+        p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+        p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1);
+        p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+        p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3);
+
+        sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+        sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(_mm256_set_m128i(sumi_1, sumi_0))), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
 #else

     int8_t aux8[QK_K];
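
The 6-bit values are rebuilt as (low nibble | selected 2 qh bits << 4), and the constant 32 offset is removed through the maddubs(m32s, q8) products, again because _mm_maddubs_epi16 only takes unsigned inputs. A per-lane scalar sketch with a hypothetical helper, not part of the patch:

#include <stdint.h>

static inline int32_t q6_lane_dot(uint8_t lo4 /* ql nibble */, uint8_t hi2 /* 2 qh bits */,
                                  int8_t scale, int8_t q8) {
    const uint8_t q6  = (uint8_t)(lo4 | (hi2 << 4));   // q4_x = (ql & 0xF) | (qh bits << 4), range 0..63
    const int32_t p16 = q6 * q8;                       // _mm_maddubs_epi16(q4_x, q8)
    const int32_t off = 32 * q8;                       // _mm_maddubs_epi16(m32s, q8)
    return scale * (p16 - off);                        // (q6 - 32) in [-32, 31], times the sub-block scale
}
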