@@ -2539,19 +2539,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2539
2539
// Initialize accumulator with zeros
2540
2540
__m256 acc = _mm256_setzero_ps ();
2541
2541
2542
- for (int i = 0 ; i < nb ; i += 2 ) {
2543
- __m256i bx = bytes_from_crumbs (x [i + 1 ].qs , x [i ].qs );
2542
+ for (int i = 0 ; i < nb / 2 ; i ++ ) {
2543
+ __m256i bx = bytes_from_crumbs (x [i * 2 + 1 ].qs , x [i * 2 ].qs );
2544
2544
2545
2545
// Compute combined scale for the block
2546
- const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 0 ].d ) * y [i /2 ].d );
2547
- const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 1 ].d ) * y [i /2 ].d );
2548
- const __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2546
+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2547
+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2548
+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2549
+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
2549
2550
2550
2551
const __m256i off = _mm256_set1_epi8 (2 );
2551
2552
bx = _mm256_sub_epi8 (bx , off );
2552
2553
2553
2554
// Load y vector
2554
- const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i / 2 ].qs );
2555
+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2555
2556
2556
2557
// Get absolute values of x vectors
2557
2558
const __m256i ax = _mm256_sign_epi8 (bx , bx );
@@ -2604,6 +2605,7 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2604
2605
static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2605
2606
assert (n % QK3_0 == 0 );
2606
2607
const int nb = n / QK3_0 ;
2608
+ assert (nb % 2 == 0 );
2607
2609
2608
2610
const block_q3_0 * restrict x = vx ;
2609
2611
const block_q8_0 * restrict y = vy ;
@@ -2613,77 +2615,80 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
2613
2615
#if defined(__AVX2__ )
2614
2616
// Initialize accumulator with zeros
2615
2617
__m128 acc = _mm_setzero_ps ();
2616
- for (int i = 0 ; i < nb ; i ++ ) {
2617
- // Compute combined scale for the block
2618
- const __m128 scale = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i ].d ) * y [i /2 ].d );
2619
-
2620
- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2621
- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2622
-
2623
- __m256i bxx = _mm256_set1_epi64x (x [i ].qs );
2624
-
2625
- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2626
-
2627
- // shift the copies to be able to reach all values
2628
- // 255 192 128 64 0
2629
- // | | | |
2630
- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2631
- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2632
- // _______________________sssssfedcba98765432__________________________________________ shift right
2633
- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2634
- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2635
- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2636
- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2637
-
2638
- // add to itself in masked places to shift some values left one bit
2639
- // 127 64 0
2640
- // | | | | | | | | | | | | | | | |
2641
- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2642
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2643
- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2644
- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2645
- //
2646
- // 255 192 128
2647
- // | | | | | | | | | | | | | | | |
2648
- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2649
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2650
- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2651
- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2652
- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2653
- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2654
-
2655
- // collect 16 bytes from 256 into 128 bits
2656
- const __m256i shufmask = _mm256_set_epi8 (
2657
- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2658
- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2659
- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2618
+ for (int i = 0 ; i < nb /2 ; i ++ ) {
2619
+ const __m128 scale_y = _mm_set1_ps (y [i ].d );
2620
+ for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2621
+ // Compute combined scale for the block
2622
+ const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2623
+ const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2624
+
2625
+ __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2626
+
2627
+ // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2628
+
2629
+ // shift the copies to be able to reach all values
2630
+ // 255 192 128 64 0
2631
+ // | | | |
2632
+ // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2633
+ // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2634
+ // _______________________sssssfedcba98765432__________________________________________ shift right
2635
+ // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2636
+ // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2637
+ // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2638
+ const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2639
+ const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2640
+ bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2641
+
2642
+ // add to itself in masked places to shift some values left one bit
2643
+ // 127 64 0
2644
+ // | | | | | | | | | | | | | | | |
2645
+ // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2646
+ // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2647
+ // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2648
+ // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2649
+ //
2650
+ // 255 192 128
2651
+ // | | | | | | | | | | | | | | | |
2652
+ // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2653
+ // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2654
+ // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2655
+ // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2656
+ const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2657
+ bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2658
+
2659
+ // collect 16 bytes from 256 into 128 bits
2660
+ const __m256i shufmask = _mm256_set_epi8 (
2661
+ 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2662
+ -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2663
+ bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2664
+
2665
+ __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2666
+
2667
+ const __m128i mask = _mm_set1_epi8 (7 );
2668
+ bx = _mm_and_si128 (mask , bx );
2669
+
2670
+ const __m128i off = _mm_set1_epi8 (4 );
2671
+ bx = _mm_sub_epi8 (bx , off );
2672
+
2673
+ const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
2660
2674
2661
- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2662
-
2663
- const __m128i mask = _mm_set1_epi8 (7 );
2664
- bx = _mm_and_si128 (mask , bx );
2665
-
2666
- const __m128i off = _mm_set1_epi8 (4 );
2667
- bx = _mm_sub_epi8 (bx , off );
2668
-
2669
- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i /2 ].qs + (i %2 )* QK3_0 ));
2670
-
2671
- // Get absolute values of x vectors
2672
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2673
- // Sign the values of the y vectors
2674
- const __m128i sy = _mm_sign_epi8 (by , bx );
2675
- // Perform multiplication and create 16-bit values
2676
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2675
+ // Get absolute values of x vectors
2676
+ const __m128i ax = _mm_sign_epi8 (bx , bx );
2677
+ // Sign the values of the y vectors
2678
+ const __m128i sy = _mm_sign_epi8 (by , bx );
2679
+ // Perform multiplication and create 16-bit values
2680
+ const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2677
2681
2678
- // Convert int16_t to int32_t by adding pairwise
2679
- const __m128i ones = _mm_set1_epi16 (1 );
2680
- __m128i i32 = _mm_madd_epi16 (dot , ones );
2682
+ // Convert int16_t to int32_t by adding pairwise
2683
+ const __m128i ones = _mm_set1_epi16 (1 );
2684
+ __m128i i32 = _mm_madd_epi16 (dot , ones );
2681
2685
2682
- // Convert int32_t to float
2683
- const __m128 p = _mm_cvtepi32_ps (i32 );
2686
+ // Convert int32_t to float
2687
+ const __m128 p = _mm_cvtepi32_ps (i32 );
2684
2688
2685
- // Apply the scale, and accumulate
2686
- acc = _mm_fmadd_ps (scale , p , acc );
2689
+ // Apply the scale, and accumulate
2690
+ acc = _mm_fmadd_ps (scale , p , acc );
2691
+ }
2687
2692
}
2688
2693
2689
2694
// Return horizontal sum of the acc vector
0 commit comments