Skip to content

Commit 8c90a86

Browse files
committed
More AVX2 optimizations
1 parent c29ab90 commit 8c90a86

File tree

1 file changed

+78
-73
lines changed

1 file changed

+78
-73
lines changed

ggml.c

Lines changed: 78 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -2539,19 +2539,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
25392539
// Initialize accumulator with zeros
25402540
__m256 acc = _mm256_setzero_ps();
25412541

2542-
for (int i = 0; i < nb; i += 2) {
2543-
__m256i bx = bytes_from_crumbs(x[i+1].qs, x[i].qs);
2542+
for (int i = 0; i < nb/2; i++) {
2543+
__m256i bx = bytes_from_crumbs(x[i*2+1].qs, x[i*2].qs);
25442544

25452545
// Compute combined scale for the block
2546-
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d);
2547-
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d);
2548-
const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
2546+
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+0].d));
2547+
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+1].d));
2548+
__m256 scale = _mm256_set_m128(scale_hi, scale_lo);
2549+
scale = _mm256_mul_ps(scale, _mm256_broadcast_ss(&y[i].d));
25492550

25502551
const __m256i off = _mm256_set1_epi8(2);
25512552
bx = _mm256_sub_epi8(bx, off);
25522553

25532554
// Load y vector
2554-
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
2555+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
25552556

25562557
// Get absolute values of x vectors
25572558
const __m256i ax = _mm256_sign_epi8(bx, bx);
@@ -2604,6 +2605,7 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
26042605
static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
26052606
assert(n % QK3_0 == 0);
26062607
const int nb = n / QK3_0;
2608+
assert(nb % 2 == 0);
26072609

26082610
const block_q3_0 * restrict x = vx;
26092611
const block_q8_0 * restrict y = vy;
@@ -2613,77 +2615,80 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
26132615
#if defined(__AVX2__)
26142616
// Initialize accumulator with zeros
26152617
__m128 acc = _mm_setzero_ps();
2616-
for (int i = 0; i < nb; i++) {
2617-
// Compute combined scale for the block
2618-
const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
2619-
2620-
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0);
2621-
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64);
2622-
2623-
__m256i bxx = _mm256_set1_epi64x(x[i].qs);
2624-
2625-
// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2626-
2627-
// shift the copies to be able to reach all values
2628-
// 255 192 128 64 0
2629-
// | | | |
2630-
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2631-
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2632-
// _______________________sssssfedcba98765432__________________________________________ shift right
2633-
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2634-
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2635-
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2636-
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r));
2637-
2638-
// add to itself in masked places to shift some values left one bit
2639-
// 127 64 0
2640-
// | | | | | | | | | | | | | | | |
2641-
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2642-
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2643-
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2644-
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2645-
//
2646-
// 255 192 128
2647-
// | | | | | | | | | | | | | | | |
2648-
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2649-
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2650-
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2651-
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2652-
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000);
2653-
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx));
2654-
2655-
// collect 16 bytes from 256 into 128 bits
2656-
const __m256i shufmask = _mm256_set_epi8(
2657-
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1,
2658-
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0);
2659-
bxx = _mm256_shuffle_epi8(bxx, shufmask);
2618+
for (int i = 0; i < nb/2; i++) {
2619+
const __m128 scale_y = _mm_set1_ps(y[i].d);
2620+
for (int u = 0; u < 2; u++) { // let the compiler unroll this
2621+
// Compute combined scale for the block
2622+
const __m128 scale_x = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+u].d));
2623+
const __m128 scale = _mm_mul_ps(scale_x, scale_y);
2624+
2625+
__m256i bxx = _mm256_set1_epi64x(x[i*2+u].qs);
2626+
2627+
// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2628+
2629+
// shift the copies to be able to reach all values
2630+
// 255 192 128 64 0
2631+
// | | | |
2632+
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2633+
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2634+
// _______________________sssssfedcba98765432__________________________________________ shift right
2635+
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2636+
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2637+
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2638+
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0);
2639+
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64);
2640+
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r));
2641+
2642+
// add to itself in masked places to shift some values left one bit
2643+
// 127 64 0
2644+
// | | | | | | | | | | | | | | | |
2645+
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2646+
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2647+
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2648+
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2649+
//
2650+
// 255 192 128
2651+
// | | | | | | | | | | | | | | | |
2652+
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2653+
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2654+
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2655+
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2656+
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000);
2657+
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx));
2658+
2659+
// collect 16 bytes from 256 into 128 bits
2660+
const __m256i shufmask = _mm256_set_epi8(
2661+
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1,
2662+
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0);
2663+
bxx = _mm256_shuffle_epi8(bxx, shufmask);
2664+
2665+
__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1));
2666+
2667+
const __m128i mask = _mm_set1_epi8(7);
2668+
bx = _mm_and_si128(mask, bx);
2669+
2670+
const __m128i off = _mm_set1_epi8(4);
2671+
bx = _mm_sub_epi8(bx, off);
2672+
2673+
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + u*QK3_0));
26602674

2661-
__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1));
2662-
2663-
const __m128i mask = _mm_set1_epi8(7);
2664-
bx = _mm_and_si128(mask, bx);
2665-
2666-
const __m128i off = _mm_set1_epi8(4);
2667-
bx = _mm_sub_epi8(bx, off);
2668-
2669-
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK3_0));
2670-
2671-
// Get absolute values of x vectors
2672-
const __m128i ax = _mm_sign_epi8(bx, bx);
2673-
// Sign the values of the y vectors
2674-
const __m128i sy = _mm_sign_epi8(by, bx);
2675-
// Perform multiplication and create 16-bit values
2676-
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2675+
// Get absolute values of x vectors
2676+
const __m128i ax = _mm_sign_epi8(bx, bx);
2677+
// Sign the values of the y vectors
2678+
const __m128i sy = _mm_sign_epi8(by, bx);
2679+
// Perform multiplication and create 16-bit values
2680+
const __m128i dot = _mm_maddubs_epi16(ax, sy);
26772681

2678-
// Convert int16_t to int32_t by adding pairwise
2679-
const __m128i ones = _mm_set1_epi16(1);
2680-
__m128i i32 = _mm_madd_epi16(dot, ones);
2682+
// Convert int16_t to int32_t by adding pairwise
2683+
const __m128i ones = _mm_set1_epi16(1);
2684+
__m128i i32 = _mm_madd_epi16(dot, ones);
26812685

2682-
// Convert int32_t to float
2683-
const __m128 p = _mm_cvtepi32_ps(i32);
2686+
// Convert int32_t to float
2687+
const __m128 p = _mm_cvtepi32_ps(i32);
26842688

2685-
// Apply the scale, and accumulate
2686-
acc = _mm_fmadd_ps(scale, p, acc);
2689+
// Apply the scale, and accumulate
2690+
acc = _mm_fmadd_ps(scale, p, acc);
2691+
}
26872692
}
26882693

26892694
// Return horizontal sum of the acc vector

0 commit comments

Comments
 (0)