Skip to content

Commit b51101a

Browse files
committed
A faster version for Q4_1 x Q8_0 dot products
The idea behind this is that Q8_0 quantized values get used many times in the matrix multiplications where they are involved. In the current implementations, when we are evaluating the dot products, we need to compute the sum of the quants in the Q8_0 vector, so the same operation is repeated many times. Here we pre-compute the sum during Q8_0 quantization, store it in the now modified block_q8_0 struct, and then reuse this result in the subsequent dot products. In a synthetic benchmark (just compute a bunch of dot products), this change speeds up the Q4_1 * Q8_0 dot product by 80%, making the performance identical to Q4_0 * Q8_0. In practical application, I see a ~15% gain in speed for token prediction on M2, and ~5% gain on Ryzen 7950X. The speed gain in the prompt evaluation is much bigger (around 50%). I have only done the change for the scalar version, ARM_NEON, and AVX2, so we still need an AVX implementation.
1 parent 02d6988 commit b51101a

File tree

3 files changed

+268
-37
lines changed

3 files changed

+268
-37
lines changed

ggml.c

Lines changed: 91 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -647,9 +647,10 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
647647
#define QK8_0 32
648648
typedef struct {
649649
float d; // delta
650+
float s; // d * sum(qs[i])
650651
int8_t qs[QK8_0]; // quants
651652
} block_q8_0;
652-
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
653+
static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
653654

654655

655656
// reference implementation for deterministic creation of model files
@@ -1247,10 +1248,13 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
12471248

12481249
y[i].d = d;
12491250

1251+
int sum = 0;
12501252
for (int l = 0; l < QK8_0; ++l) {
12511253
const float v = x[i*QK8_0 + l]*id;
12521254
y[i].qs[l] = roundf(v);
1255+
sum += y[i].qs[l];
12531256
}
1257+
y[i].s = d * sum;
12541258
}
12551259
}
12561260

@@ -1280,6 +1284,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
12801284

12811285
y[i].d = d;
12821286

1287+
int32x4_t accv = vdupq_n_s32(0);
1288+
12831289
for (int l = 0; l < 8; l++) {
12841290
const float32x4_t v = vmulq_n_f32(srcv[l], id);
12851291
const int32x4_t vi = vcvtnq_s32_f32(v);
@@ -1288,7 +1294,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
12881294
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
12891295
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
12901296
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1297+
1298+
accv = vaddq_s32(accv, vi);
12911299
}
1300+
int32_t sum = vaddvq_s32(accv);
1301+
y[i].s = d * sum;
12921302
}
12931303
#elif defined(__AVX2__) || defined(__AVX__)
12941304
for (int i = 0; i < nb; i++) {
@@ -1336,6 +1346,16 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13361346
__m256i i3 = _mm256_cvtps_epi32( v3 );
13371347

13381348
#if defined(__AVX2__)
1349+
1350+
// Compute the sum of the quants
1351+
// There is not better way of doing this???
1352+
__m256i acc = _mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3));
1353+
__m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
1354+
__m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
1355+
__m128i sum64 = _mm_add_epi32(hi64, sum128);
1356+
__m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
1357+
y[i].s = d * _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
1358+
13391359
// Convert int32 to int16
13401360
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
13411361
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
@@ -1378,6 +1398,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13781398
// scalar
13791399
quantize_row_q8_0_reference(x, y, k);
13801400
#endif
1401+
#if defined __AVX__
1402+
// TODO: vectorize this
1403+
for (int i=0; i<nb; ++i) {
1404+
int sum = 0;
1405+
for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l];
1406+
y[i].s = y[i].d * sum;
1407+
}
1408+
#endif
13811409
}
13821410

13831411
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
@@ -2282,14 +2310,18 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
22822310
float32x4_t sumv0 = vdupq_n_f32(0.0f);
22832311
float32x4_t sumv1 = vdupq_n_f32(0.0f);
22842312

2313+
float sum8 = 0;
2314+
22852315
for (int i = 0; i < nb; i += 2) {
22862316
const block_q4_0 * restrict x0 = &x[i + 0];
22872317
const block_q4_0 * restrict x1 = &x[i + 1];
22882318
const block_q8_0 * restrict y0 = &y[i + 0];
22892319
const block_q8_0 * restrict y1 = &y[i + 1];
22902320

2321+
sum8 += x0->d * y0->s + x1->d * y1->s;
2322+
22912323
const uint8x16_t m4b = vdupq_n_u8(0xf);
2292-
const int8x16_t s8b = vdupq_n_s8(0x8);
2324+
//const int8x16_t s8b = vdupq_n_s8(0x8);
22932325

22942326
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
22952327
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2301,10 +2333,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23012333
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
23022334

23032335
// sub 8
2304-
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2305-
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2306-
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2307-
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2336+
//const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2337+
//const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2338+
//const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2339+
//const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
23082340

23092341
// load y
23102342
const int8x16_t v1_0l = vld1q_s8(y0->qs);
@@ -2320,21 +2352,31 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23202352

23212353
#if defined(__ARM_FEATURE_DOTPROD)
23222354
// dot product into int32x4_t
2323-
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2324-
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
2355+
//const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2356+
//const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
2357+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2358+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
23252359

23262360
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
23272361
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
23282362
#else
2329-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2330-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2331-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2332-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2363+
//const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2364+
//const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2365+
//const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2366+
//const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
2367+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2368+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2369+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2370+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
23332371

2334-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2335-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2336-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2337-
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2372+
//const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2373+
//const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2374+
//const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2375+
//const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
2376+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2377+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2378+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2379+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
23382380

23392381
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
23402382
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2346,7 +2388,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23462388
#endif
23472389
}
23482390

2349-
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2391+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
23502392
#elif defined(__AVX2__)
23512393
// Initialize accumulator with zeros
23522394
__m256 acc = _mm256_setzero_ps();
@@ -2479,12 +2521,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
24792521
float32x4_t sumv0 = vdupq_n_f32(0.0f);
24802522
float32x4_t sumv1 = vdupq_n_f32(0.0f);
24812523

2524+
float summs = 0;
2525+
24822526
for (int i = 0; i < nb; i += 2) {
24832527
const block_q4_1 * restrict x0 = &x[i + 0];
24842528
const block_q4_1 * restrict x1 = &x[i + 1];
24852529
const block_q8_0 * restrict y0 = &y[i + 0];
24862530
const block_q8_0 * restrict y1 = &y[i + 1];
24872531

2532+
summs += x0->m * y0->s + x1->m * y1->s;
2533+
24882534
const uint8x16_t m4b = vdupq_n_u8(0xf);
24892535

24902536
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2508,16 +2554,18 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25082554
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
25092555
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
25102556

2511-
const int16x8_t s0i = vaddq_s16(
2512-
vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2513-
vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
2557+
// We no longer need this. We have computed the sum of the y quants during quantization,
2558+
// so we get the same as these via the scalar instruction above (summs += x0->m * y0->s + x1->m * y1->s)
2559+
//const int16x8_t s0i = vaddq_s16(
2560+
// vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
2561+
// vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
25142562

2515-
const int16x8_t s1i = vaddq_s16(
2516-
vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2517-
vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
2563+
//const int16x8_t s1i = vaddq_s16(
2564+
// vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
2565+
// vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
25182566

2519-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2520-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
2567+
//sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
2568+
//sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
25212569

25222570
#if defined(__ARM_FEATURE_DOTPROD)
25232571
// dot product into int32x4_t
@@ -2547,24 +2595,28 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25472595
#endif
25482596
}
25492597

2550-
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2598+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
25512599
#elif defined(__AVX2__)
25522600
// Initialize accumulator with zeros
25532601
__m256 acc = _mm256_setzero_ps();
25542602

2603+
float summs = 0;
2604+
25552605
// Main loop
25562606
for (int i = 0; i < nb; ++i) {
25572607
const float * d0 = &x[i].d;
25582608
const float * d1 = &y[i].d;
2559-
const float * m0 = &x[i].m;
2609+
//const float * m0 = &x[i].m;
2610+
2611+
summs += x[i].m * y[i].s;
25602612

25612613
const __m256 d0v = _mm256_broadcast_ss( d0 );
25622614
const __m256 d1v = _mm256_broadcast_ss( d1 );
2563-
const __m256 m0v = _mm256_broadcast_ss( m0 );
2615+
//const __m256 m0v = _mm256_broadcast_ss( m0 );
25642616

25652617
// Compute combined scales
25662618
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2567-
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2619+
//const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
25682620

25692621
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
25702622
const __m256i bx = bytesFromNibbles( x[i].qs );
@@ -2587,14 +2639,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25872639
// Accumulate d0*d1*x*y
25882640
acc = _mm256_fmadd_ps( d0d1, xy, acc );
25892641

2590-
// Compute sum of y values
2591-
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2592-
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2593-
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2594-
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2642+
// We no longer need this. We have computed the sum of the y quants during quantization,
2643+
// so we get the same as these via the single scalar instruction above (summs += x[i].m * y[i].s)
2644+
//// Compute sum of y values
2645+
//const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2646+
//const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2647+
//const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2648+
//const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
25952649

2596-
// Accumulate d1*m0*y
2597-
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2650+
//// Accumulate d1*m0*y
2651+
//acc = _mm256_fmadd_ps( d1m0, ysum, acc );
25982652
}
25992653

26002654
// Return horizontal sum of the acc vector
@@ -2603,7 +2657,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
26032657
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
26042658
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
26052659

2606-
sumf = _mm_cvtss_f32( res );
2660+
sumf = _mm_cvtss_f32( res ) + summs;
26072661
#else
26082662
// scalar
26092663
for (int i = 0; i < nb; i++) {

pocs/vdot/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,8 @@ set(TARGET vdot)
22
add_executable(${TARGET} vdot.cpp)
33
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
44
target_compile_features(${TARGET} PRIVATE cxx_std_11)
5+
6+
set(TARGET q8dot)
7+
add_executable(${TARGET} q8dot.cpp)
8+
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
9+
target_compile_features(${TARGET} PRIVATE cxx_std_11)

0 commit comments

Comments
 (0)