Skip to content

Commit 440b8fb

Browse files
committed
Q2 AVX2: do two blocks at a time, by @slaren
1 parent 2629192 commit 440b8fb

File tree

1 file changed

+50
-28
lines changed

1 file changed

+50
-28
lines changed

ggml.c

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,35 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
427427
// quantization
428428
//
429429

430-
// AVX routines provided by GH user Const-me
431-
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
432430
#if __AVX2__ || __AVX512F__
431+
// Unpack 32 2-bit fields into 32 bytes
432+
// The output vector contains 32 bytes, each one in [ 0 .. 3 ] interval
433+
static inline __m256i bytesFromCrumbs(uint32_t packed_hi, uint32_t packed_lo) {
434+
__m128i bx_hi = _mm_set1_epi32(packed_hi);
435+
__m128i bx_lo = _mm_set1_epi32(packed_lo);
436+
__m256i bx = _mm256_set_m128i(bx_hi, bx_lo);
437+
438+
// shift counts to get all bit pairs in lowest position of each byte
439+
const __m256i shift256 = _mm256_set_epi32(6, 4, 2, 0,
440+
6, 4, 2, 0);
441+
bx = _mm256_srlv_epi32(bx, shift256);
442+
443+
const __m256i shufmask = _mm256_set_epi8(15,11, 7, 3,
444+
14,10, 6, 2,
445+
13, 9, 5, 1,
446+
12, 8, 4, 0,
447+
15,11, 7, 3,
448+
14,10, 6, 2,
449+
13, 9, 5, 1,
450+
12, 8, 4, 0);
451+
bx = _mm256_shuffle_epi8(bx, shufmask);
452+
453+
const __m256i mask = _mm256_set1_epi8(3);
454+
bx = _mm256_and_si256(mask, bx);
455+
456+
return bx;
457+
}
458+
433459
// Unpack 32 4-bit fields into 32 bytes
434460
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
435461
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
@@ -2368,6 +2394,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
23682394
static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
23692395
assert(n % QK2_0 == 0);
23702396
const int nb = n / QK2_0;
2397+
assert(nb % 2 == 0);
23712398

23722399
const block_q2_0 * restrict x = vx;
23732400
const block_q8_0 * restrict y = vy;
@@ -2376,49 +2403,44 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
23762403

23772404
#if defined(__AVX2__)
23782405
// Initialize accumulator with zeros
2379-
__m128 acc = _mm_setzero_ps();
2380-
2381-
for (int i = 0; i < nb; i++) {
2382-
// Compute combined scale for the block
2383-
const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
2384-
2385-
__m128i bx = _mm_set1_epi32(x[i].qs);
2406+
__m256 acc = _mm256_setzero_ps();
23862407

2387-
// shift counts to get all bit pairs in lowest position of each byte
2388-
const __m128i shift128 = _mm_set_epi32(6, 4, 2, 0);
2389-
bx = _mm_srlv_epi32(bx, shift128);
2408+
for (int i = 0; i < nb; i += 2) {
2409+
__m256i bx = bytesFromCrumbs(x[i+1].qs, x[i].qs);
23902410

2391-
const __m128i shufmask = _mm_set_epi8(15,11,7,3,14,10,6,2,13,9,5,1,12,8,4,0);
2392-
bx = _mm_shuffle_epi8(bx, shufmask);
2411+
// Compute combined scale for the block
2412+
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d);
2413+
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d);
2414+
const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
23932415

2394-
const __m128i mask = _mm_set1_epi8(3);
2395-
bx = _mm_and_si128(mask, bx);
2416+
const __m256i off = _mm256_set1_epi8(2);
2417+
bx = _mm256_sub_epi8(bx, off);
23962418

2397-
const __m128i off = _mm_set1_epi8(2);
2398-
bx = _mm_sub_epi8(bx, off);
2399-
2400-
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK2_0));
2419+
// Load y vector
2420+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
24012421

24022422
// Get absolute values of x vectors
2403-
const __m128i ax = _mm_sign_epi8(bx, bx);
2423+
const __m256i ax = _mm256_sign_epi8(bx, bx);
24042424
// Sign the values of the y vectors
2405-
const __m128i sy = _mm_sign_epi8(by, bx);
2425+
const __m256i sy = _mm256_sign_epi8(by, bx);
24062426
// Perform multiplication and create 16-bit values
2407-
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2427+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
24082428

24092429
// Convert int16_t to int32_t by adding pairwise
2410-
const __m128i ones = _mm_set1_epi16(1);
2411-
__m128i i32 = _mm_madd_epi16(dot, ones);
2430+
const __m256i ones = _mm256_set1_epi16(1);
2431+
__m256i i32 = _mm256_madd_epi16(ones, dot);
24122432

24132433
// Convert int32_t to float
2414-
const __m128 p = _mm_cvtepi32_ps(i32);
2434+
__m256 p = _mm256_cvtepi32_ps(i32);
24152435

24162436
// Apply the scale, and accumulate
2417-
acc = _mm_fmadd_ps(scale, p, acc);
2437+
acc = _mm256_fmadd_ps(scale, p, acc);
24182438
}
24192439

24202440
// Return horizontal sum of the acc vector
2421-
__m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
2441+
__m128 res = _mm256_extractf128_ps(acc, 1);
2442+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2443+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
24222444
res = _mm_add_ss(res, _mm_movehdup_ps(res));
24232445
sumf = _mm_cvtss_f32(res);
24242446
#else

0 commit comments

Comments
 (0)