@@ -488,6 +488,34 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
488
488
}
489
489
490
490
#if __AVX2__ || __AVX512F__
491
+ // Unpack 32 2-bit fields into 32 bytes
492
+ // The output vector contains 32 bytes, each one in [ 0 .. 3 ] interval
493
+ static inline __m256i bytes_from_crumbs (uint32_t packed_hi , uint32_t packed_lo ) {
494
+ __m128i bx_hi = _mm_set1_epi32 (packed_hi );
495
+ __m128i bx_lo = _mm_set1_epi32 (packed_lo );
496
+ __m256i bx = _mm256_set_m128i (bx_hi , bx_lo );
497
+
498
+ // shift counts to get all bit pairs in lowest position of each byte
499
+ const __m256i shift256 = _mm256_set_epi32 (6 , 4 , 2 , 0 ,
500
+ 6 , 4 , 2 , 0 );
501
+ bx = _mm256_srlv_epi32 (bx , shift256 );
502
+
503
+ const __m256i shufmask = _mm256_set_epi8 (15 ,11 , 7 , 3 ,
504
+ 14 ,10 , 6 , 2 ,
505
+ 13 , 9 , 5 , 1 ,
506
+ 12 , 8 , 4 , 0 ,
507
+ 15 ,11 , 7 , 3 ,
508
+ 14 ,10 , 6 , 2 ,
509
+ 13 , 9 , 5 , 1 ,
510
+ 12 , 8 , 4 , 0 );
511
+ bx = _mm256_shuffle_epi8 (bx , shufmask );
512
+
513
+ const __m256i mask = _mm256_set1_epi8 (3 );
514
+ bx = _mm256_and_si256 (mask , bx );
515
+
516
+ return bx ;
517
+ }
518
+
491
519
// Unpack 32 4-bit fields into 32 bytes
492
520
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
493
521
static inline __m256i bytes_from_nibbles_32 (const uint8_t * rsi )
@@ -2500,6 +2528,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2500
2528
static void ggml_vec_dot_q2_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2501
2529
assert (n % QK2_0 == 0 );
2502
2530
const int nb = n / QK2_0 ;
2531
+ assert (nb % 2 == 0 );
2503
2532
2504
2533
const block_q2_0 * restrict x = vx ;
2505
2534
const block_q8_0 * restrict y = vy ;
@@ -2508,49 +2537,44 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2508
2537
2509
2538
#if defined(__AVX2__ )
2510
2539
// Initialize accumulator with zeros
2511
- __m128 acc = _mm_setzero_ps ();
2512
-
2513
- for (int i = 0 ; i < nb ; i ++ ) {
2514
- // Compute combined scale for the block
2515
- const __m128 scale = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i ].d ) * y [i /2 ].d );
2516
-
2517
- __m128i bx = _mm_set1_epi32 (x [i ].qs );
2518
-
2519
- // shift counts to get all bit pairs in lowest position of each byte
2520
- const __m128i shift128 = _mm_set_epi32 (6 , 4 , 2 , 0 );
2521
- bx = _mm_srlv_epi32 (bx , shift128 );
2540
+ __m256 acc = _mm256_setzero_ps ();
2522
2541
2523
- const __m128i shufmask = _mm_set_epi8 ( 15 , 11 , 7 , 3 , 14 , 10 , 6 , 2 , 13 , 9 , 5 , 1 , 12 , 8 , 4 , 0 );
2524
- bx = _mm_shuffle_epi8 ( bx , shufmask );
2542
+ for ( int i = 0 ; i < nb ; i += 2 ) {
2543
+ __m256i bx = bytes_from_crumbs ( x [ i + 1 ]. qs , x [ i ]. qs );
2525
2544
2526
- const __m128i mask = _mm_set1_epi8 (3 );
2527
- bx = _mm_and_si128 (mask , bx );
2545
+ // Compute combined scale for the block
2546
+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 0 ].d ) * y [i /2 ].d );
2547
+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 1 ].d ) * y [i /2 ].d );
2548
+ const __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2528
2549
2529
- const __m128i off = _mm_set1_epi8 (2 );
2530
- bx = _mm_sub_epi8 (bx , off );
2550
+ const __m256i off = _mm256_set1_epi8 (2 );
2551
+ bx = _mm256_sub_epi8 (bx , off );
2531
2552
2532
- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i /2 ].qs + (i %2 )* QK2_0 ));
2553
+ // Load y vector
2554
+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i /2 ].qs );
2533
2555
2534
2556
// Get absolute values of x vectors
2535
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2557
+ const __m256i ax = _mm256_sign_epi8 (bx , bx );
2536
2558
// Sign the values of the y vectors
2537
- const __m128i sy = _mm_sign_epi8 (by , bx );
2559
+ const __m256i sy = _mm256_sign_epi8 (by , bx );
2538
2560
// Perform multiplication and create 16-bit values
2539
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2561
+ const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2540
2562
2541
2563
// Convert int16_t to int32_t by adding pairwise
2542
- const __m128i ones = _mm_set1_epi16 (1 );
2543
- __m128i i32 = _mm_madd_epi16 ( dot , ones );
2564
+ const __m256i ones = _mm256_set1_epi16 (1 );
2565
+ __m256i i32 = _mm256_madd_epi16 ( ones , dot );
2544
2566
2545
2567
// Convert int32_t to float
2546
- const __m128 p = _mm_cvtepi32_ps (i32 );
2568
+ __m256 p = _mm256_cvtepi32_ps (i32 );
2547
2569
2548
2570
// Apply the scale, and accumulate
2549
- acc = _mm_fmadd_ps (scale , p , acc );
2571
+ acc = _mm256_fmadd_ps (scale , p , acc );
2550
2572
}
2551
2573
2552
2574
// Return horizontal sum of the acc vector
2553
- __m128 res = _mm_add_ps (acc , _mm_movehl_ps (acc , acc ));
2575
+ __m128 res = _mm256_extractf128_ps (acc , 1 );
2576
+ res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2577
+ res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2554
2578
res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2555
2579
sumf = _mm_cvtss_f32 (res );
2556
2580
#else
0 commit comments