Skip to content

Commit c29ab90

Browse files
committed
Q2 AVX2: do two blocks at a time, by @slaren
1 parent 6fc51a8 commit c29ab90

File tree

1 file changed

+50
-26
lines changed

1 file changed

+50
-26
lines changed

ggml.c

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,34 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
488488
}
489489

490490
#if __AVX2__ || __AVX512F__
491+
// Unpack 32 2-bit fields into 32 bytes
492+
// The output vector contains 32 bytes, each one in [ 0 .. 3 ] interval
493+
static inline __m256i bytes_from_crumbs(uint32_t packed_hi, uint32_t packed_lo) {
494+
__m128i bx_hi = _mm_set1_epi32(packed_hi);
495+
__m128i bx_lo = _mm_set1_epi32(packed_lo);
496+
__m256i bx = _mm256_set_m128i(bx_hi, bx_lo);
497+
498+
// shift counts to get all bit pairs in lowest position of each byte
499+
const __m256i shift256 = _mm256_set_epi32(6, 4, 2, 0,
500+
6, 4, 2, 0);
501+
bx = _mm256_srlv_epi32(bx, shift256);
502+
503+
const __m256i shufmask = _mm256_set_epi8(15,11, 7, 3,
504+
14,10, 6, 2,
505+
13, 9, 5, 1,
506+
12, 8, 4, 0,
507+
15,11, 7, 3,
508+
14,10, 6, 2,
509+
13, 9, 5, 1,
510+
12, 8, 4, 0);
511+
bx = _mm256_shuffle_epi8(bx, shufmask);
512+
513+
const __m256i mask = _mm256_set1_epi8(3);
514+
bx = _mm256_and_si256(mask, bx);
515+
516+
return bx;
517+
}
518+
491519
// Unpack 32 4-bit fields into 32 bytes
492520
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
493521
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
@@ -2500,6 +2528,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
25002528
static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
25012529
assert(n % QK2_0 == 0);
25022530
const int nb = n / QK2_0;
2531+
assert(nb % 2 == 0);
25032532

25042533
const block_q2_0 * restrict x = vx;
25052534
const block_q8_0 * restrict y = vy;
@@ -2508,49 +2537,44 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
25082537

25092538
#if defined(__AVX2__)
25102539
// Initialize accumulator with zeros
2511-
__m128 acc = _mm_setzero_ps();
2512-
2513-
for (int i = 0; i < nb; i++) {
2514-
// Compute combined scale for the block
2515-
const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
2516-
2517-
__m128i bx = _mm_set1_epi32(x[i].qs);
2518-
2519-
// shift counts to get all bit pairs in lowest position of each byte
2520-
const __m128i shift128 = _mm_set_epi32(6, 4, 2, 0);
2521-
bx = _mm_srlv_epi32(bx, shift128);
2540+
__m256 acc = _mm256_setzero_ps();
25222541

2523-
const __m128i shufmask = _mm_set_epi8(15,11,7,3,14,10,6,2,13,9,5,1,12,8,4,0);
2524-
bx = _mm_shuffle_epi8(bx, shufmask);
2542+
for (int i = 0; i < nb; i += 2) {
2543+
__m256i bx = bytes_from_crumbs(x[i+1].qs, x[i].qs);
25252544

2526-
const __m128i mask = _mm_set1_epi8(3);
2527-
bx = _mm_and_si128(mask, bx);
2545+
// Compute combined scale for the block
2546+
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d);
2547+
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d);
2548+
const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
25282549

2529-
const __m128i off = _mm_set1_epi8(2);
2530-
bx = _mm_sub_epi8(bx, off);
2550+
const __m256i off = _mm256_set1_epi8(2);
2551+
bx = _mm256_sub_epi8(bx, off);
25312552

2532-
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK2_0));
2553+
// Load y vector
2554+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
25332555

25342556
// Get absolute values of x vectors
2535-
const __m128i ax = _mm_sign_epi8(bx, bx);
2557+
const __m256i ax = _mm256_sign_epi8(bx, bx);
25362558
// Sign the values of the y vectors
2537-
const __m128i sy = _mm_sign_epi8(by, bx);
2559+
const __m256i sy = _mm256_sign_epi8(by, bx);
25382560
// Perform multiplication and create 16-bit values
2539-
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2561+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
25402562

25412563
// Convert int16_t to int32_t by adding pairwise
2542-
const __m128i ones = _mm_set1_epi16(1);
2543-
__m128i i32 = _mm_madd_epi16(dot, ones);
2564+
const __m256i ones = _mm256_set1_epi16(1);
2565+
__m256i i32 = _mm256_madd_epi16(ones, dot);
25442566

25452567
// Convert int32_t to float
2546-
const __m128 p = _mm_cvtepi32_ps(i32);
2568+
__m256 p = _mm256_cvtepi32_ps(i32);
25472569

25482570
// Apply the scale, and accumulate
2549-
acc = _mm_fmadd_ps(scale, p, acc);
2571+
acc = _mm256_fmadd_ps(scale, p, acc);
25502572
}
25512573

25522574
// Return horizontal sum of the acc vector
2553-
__m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
2575+
__m128 res = _mm256_extractf128_ps(acc, 1);
2576+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2577+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
25542578
res = _mm_add_ss(res, _mm_movehdup_ps(res));
25552579
sumf = _mm_cvtss_f32(res);
25562580
#else

0 commit comments

Comments
 (0)