Commit 66a865b

A faster version for Q4_1 x Q8_0 dot products
The idea behind this change is that Q8_0-quantized values get used many times in the matrix multiplications where they are involved. In the current implementation, every dot product recomputes the sum of the quants in the Q8_0 vector, so the same work is repeated many times. Here we pre-compute the sum during Q8_0 quantization, store it in the now-modified block_q8_0 struct, and then reuse this result in the subsequent dot products.

In a synthetic benchmark (just computing a bunch of dot products), this change speeds up the Q4_1 * Q8_0 dot product by 80%, making its performance identical to Q4_0 * Q8_0. In practical use, I see a ~15% gain in token-prediction speed on M2 and a ~5% gain on a Ryzen 7950X. The speed gain in prompt evaluation is much bigger (around 50%).

I have only made the change for the scalar version, ARM_NEON, and AVX2, so we still need an AVX implementation.
1 parent d40fded commit 66a865b
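To make the trick concrete, here is a minimal, self-contained C sketch of the pattern this commit introduces (hypothetical, simplified names — blk_q8, quantize_q8, dot_q4_1_q8 — and unpacked 4-bit quants; the real structs and kernels are in the ggml.c diff below):

// Minimal sketch, not the ggml implementation: a Q8_0-style block carrying a
// precomputed s = d * sum(qs), and a Q4_1-style dot product that reuses it.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK 32

typedef struct {
    float  d;      // delta (scale)
    float  s;      // d * sum(qs[l]), filled in once at quantization time
    int8_t qs[QK]; // quants
} blk_q8;

static void quantize_q8(const float * x, blk_q8 * y) {
    float amax = 0.0f;
    for (int l = 0; l < QK; ++l) amax = fmaxf(amax, fabsf(x[l]));
    const float d  = amax / 127.0f;
    const float id = d > 0.0f ? 1.0f/d : 0.0f;
    int sum = 0;
    for (int l = 0; l < QK; ++l) {
        y->qs[l] = (int8_t)roundf(x[l]*id);
        sum += y->qs[l];          // the sum every dot product used to recompute
    }
    y->d = d;
    y->s = d * sum;               // computed once, reused by every dot product
}

// A Q4_1-style block dequantizes to d*q + m, so
//   sum_l (d*q[l] + m) * (y->d * y->qs[l]) = d*y->d * dot(q, y->qs) + m*y->s
// i.e. the whole "+ m * sum of y" term collapses into one scalar multiply-add.
static float dot_q4_1_q8(float d, float m, const uint8_t * q, const blk_q8 * y) {
    int sumi = 0;
    for (int l = 0; l < QK; ++l) sumi += q[l] * y->qs[l];
    return d * y->d * sumi + m * y->s;
}

int main(void) {
    float   x[QK];
    uint8_t q[QK];
    for (int l = 0; l < QK; ++l) { x[l] = 0.1f*l - 1.5f; q[l] = (uint8_t)(l & 15); }
    blk_q8 y;
    quantize_q8(x, &y);
    printf("dot = %g\n", dot_q4_1_q8(0.5f, -1.0f, q, &y));
    return 0;
}

In a matrix multiplication each quantized activation row is dotted against many weight rows, so the per-block sum is paid once at quantization time instead of once per dot product.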

File tree

3 files changed: 268 additions & 37 deletions


ggml.c

Lines changed: 91 additions & 37 deletions
@@ -657,9 +657,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float d;          // delta
+    float s;          // d * sum(qs[i])
     int8_t qs[QK8_0]; // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");


 // reference implementation for deterministic creation of model files
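For scale: the new field grows each Q8_0 block from sizeof(float) + QK8_0 = 36 bytes to 2*sizeof(float) + QK8_0 = 40 bytes, about 11% more memory per block, in exchange for never recomputing the quant sums inside the dot products.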
@@ -1299,10 +1300,13 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r

         y[i].d = d;

+        int sum = 0;
         for (int l = 0; l < QK8_0; ++l) {
             const float v = x[i*QK8_0 + l]*id;
             y[i].qs[l] = roundf(v);
+            sum += y[i].qs[l];
         }
+        y[i].s = d * sum;
     }
 }

@@ -1332,6 +1336,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int

         y[i].d = d;

+        int32x4_t accv = vdupq_n_s32(0);
+
         for (int l = 0; l < 8; l++) {
             const float32x4_t v = vmulq_n_f32(srcv[l], id);
             const int32x4_t vi = vcvtnq_s32_f32(v);
@@ -1340,7 +1346,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
         }
+        int32_t sum = vaddvq_s32(accv);
+        y[i].s = d * sum;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
     for (int i = 0; i < nb; i++) {
@@ -1388,6 +1398,16 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );

 #if defined(__AVX2__)
+
+        // Compute the sum of the quants
+        // There is not better way of doing this???
+        __m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
+        __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
+        __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+        __m128i sum64 = _mm_add_epi32(hi64, sum128);
+        __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+        y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
         i2 = _mm256_packs_epi32( i2, i3 );  // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
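Tracing the reduction above: acc holds eight 32-bit partial sums; sum128 adds the two 128-bit halves down to four lanes; hi64 and sum64 fold those four lanes to two; hi32 swaps the remaining pair so the final _mm_add_epi32 leaves the total in lane 0, which _mm_cvtsi128_si32 extracts as the sum of all 32 quants.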
@@ -1430,6 +1450,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
+#if defined __AVX__
+    // TODO: vectorize this
+    for (int i=0; i<nb; ++i) {
+        int sum = 0;
+        for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
+        y[i].s = y[i].d * sum;
+    }
+#endif
 }

 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
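The separate #if defined __AVX__ fix-up appears to be needed because plain AVX takes the same vectorized path as AVX2 (the #elif above covers both), while the in-loop sum is guarded by #if defined(__AVX2__) — so on AVX-only builds y[i].s would otherwise never be set.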
@@ -2372,14 +2400,18 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

+    float sum8 = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_0 * restrict x0 = &x[i + 0];
         const block_q4_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

+        sum8 += x0->d * y0->s + x1->d * y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t s8b = vdupq_n_s8(0x8);
+        //const int8x16_t s8b = vdupq_n_s8(0x8);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2391,10 +2423,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

         // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+        //const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        //const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        //const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        //const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);

         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
@@ -2410,21 +2442,31 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        //const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
+        //const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        //const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        //const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        //const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        //const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));

-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        //const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        //const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+        //const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        //const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));

         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2436,7 +2478,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 #endif
     }

-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
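The algebra behind the new ending: a Q4_0 block dequantizes to d*(q - 8) with q in 0..15, so per block

    sum_l d*(q[l] - 8) * y->d*y->qs[l] = d*y->d * sum_l q[l]*y->qs[l] - 8*d * (y->d * sum_l y->qs[l])

and the parenthesized factor is exactly the precomputed y->s. The kernel can therefore feed the raw unsigned nibbles (v0_0l, v0_0h, ...) straight into the dot products and fold every -8 correction into sum8 (the running sum of x->d * y->s over all blocks), subtracted once at the end as - 8 * sum8.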
@@ -2569,12 +2611,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);

+    float summs = 0;
+
     for (int i = 0; i < nb; i += 2) {
         const block_q4_1 * restrict x0 = &x[i + 0];
         const block_q4_1 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];

+        summs += x0->m * y0->s + x1->m * y1->s;
+
         const uint8x16_t m4b = vdupq_n_u8(0xf);

         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2598,16 +2644,18 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
         const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);

-        const int16x8_t s0i = vaddq_s16(
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+        // We no longer need this. We have computed the sum of the y quants during quantization,
+        // so we get the same as these via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
+        //const int16x8_t s0i = vaddq_s16(
+        //                vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+        //                vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));

-        const int16x8_t s1i = vaddq_s16(
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
-                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+        //const int16x8_t s1i = vaddq_s16(
+        //                vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+        //                vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));

-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
+        //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);

 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
@@ -2637,24 +2685,28 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
 #endif
     }

-    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

+    float summs = 0;
+
     // Main loop
     for (int i = 0; i < nb; ++i) {
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
-        const float * m0 = &x[i].m;
+        //const float * m0 = &x[i].m;
+
+        summs += x[i].m * y[i].s;

         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
+        //const __m256 m0v = _mm256_broadcast_ss( m0 );

         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
-        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
+        //const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );

         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
@@ -2677,14 +2729,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         // Accumulate d0*d1*x*y
         acc = _mm256_fmadd_ps( d0d1, xy, acc );

-        // Compute sum of y values
-        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
-        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
+        // We no longer need this. We have computed the sum of the y quants during quantization,
+        // so we get the same as these via the single scalar instruction above (summs += x[i].m * y[i].s)
+        //// Compute sum of y values
+        //const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        //const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        //const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
+        //const __m256 ysum = _mm256_cvtepi32_ps( ysumi );

-        // Accumulate d1*m0*y
-        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
+        //// Accumulate d1*m0*y
+        //acc = _mm256_fmadd_ps( d1m0, ysum, acc );
     }

     // Return horizontal sum of the acc vector
@@ -2693,7 +2747,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );

-    sumf = _mm_cvtss_f32( res );
+    sumf = _mm_cvtss_f32( res ) + summs;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
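The same derivation for Q4_1, which dequantizes to d*q + m:

    sum_l (d*q[l] + m) * y->d*y->qs[l] = d*y->d * sum_l q[l]*y->qs[l] + m * (y->d * sum_l y->qs[l]) = d*y->d * dot(q, y->qs) + m*y->s

which is why the per-block vector sums (s0i/s1i on NEON, ysum on AVX2) can be replaced by the single scalar accumulation summs += x->m * y->s, added once after the loop.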

pocs/vdot/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -2,3 +2,8 @@ set(TARGET vdot)
 add_executable(${TARGET} vdot.cpp)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET q8dot)
+add_executable(${TARGET} q8dot.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
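q8dot.cpp, referenced by the new target, is presumably the synthetic dot-product benchmark mentioned in the commit message; it is the third changed file in this commit.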
