Commit e116eb6

ggml : speed-up Q5_0 + Q5_1 at 4 threads
1 parent ffd76e1 commit e116eb6

File tree

1 file changed: ggml.c (+147, -76 lines changed)

ggml.c

Lines changed: 147 additions & 76 deletions
@@ -339,8 +339,9 @@ static float table_f32_f16[1 << 16];
 #define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
 #define B8(c,s  ) B7(c,s,    c), B7(c,s,    s)
 
-// precomputed tables for expanding 8bits to 8 bytes (shl 4)
-static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
+// precomputed tables for expanding 8bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
 #endif
 
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
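
Note on the tables: the B6/B7/B8 macro chain above expands to all 256 byte patterns, so each table maps a byte b to a 64-bit word in which byte k holds bit k of b (table_b2b_0) or its complement (table_b2b_1), shifted into the 0x10 position. A minimal scalar sketch of the same mapping (hypothetical helper names, not part of ggml.c):

#include <assert.h>
#include <stdint.h>

// byte k of the result is bit k of b, moved into the 0x10 (5th bit) position
static uint64_t b2b_0_ref(uint8_t b) {
    uint64_t r = 0;
    for (int k = 0; k < 8; ++k) {
        r |= (uint64_t)(((b >> k) & 1) << 4) << (8*k);
    }
    return r;
}

// same with every bit complemented: (!b) << 4 per byte
static uint64_t b2b_1_ref(uint8_t b) {
    return b2b_0_ref(b) ^ 0x1010101010101010ULL;
}

int main(void) {
    assert(b2b_0_ref(0x01) == 0x0000000000000010ULL);
    assert(b2b_1_ref(0x01) == 0x1010101010101000ULL);
    return 0;
}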
@@ -2307,68 +2308,102 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     const block_q8_0 * restrict y = vy;
 
 #if defined(__ARM_NEON)
-    float32x4_t sumv = vdupq_n_f32(0.0f);
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    uint64_t tmp[4];
+    uint32_t qh0;
+    uint32_t qh1;
 
-    for (int i = 0; i < nb; ++i) {
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    for (int i = 0; i < nb; i += 2) {
         const block_q5_0 * restrict x0 = &x[i];
+        const block_q5_0 * restrict x1 = &x[i + 1];
         const block_q8_0 * restrict y0 = &y[i];
+        const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t m4b  = vdupq_n_u8(0x0F);
-        const int8x16_t  s16b = vdupq_n_s8(0x10);
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
-        // extract the 5th bit
-        uint32_t qh;
-        memcpy(&qh, x0->qh, sizeof(qh));
+        // extract the 5th bit via lookup table ((!b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
 
-        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_u[(qh >> 24)       ];
+        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
 
-        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
-        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
 
-        const uint8x16_t v0 = vld1q_u8(x0->qs);
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
 
         // 4-bit -> 8-bit
-        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, m4b));
-        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // add high bit and sub 16
-        const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0l, qhl), s16b);
-        const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0h, qhh), s16b);
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
 
         // load y
-        const int8x16_t v1l = vld1q_s8(y0->qs);
-        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         const float x0d = GGML_FP16_TO_FP32(x0->d);
+        const float x1d = GGML_FP16_TO_FP32(x1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
-                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
 #endif
     }
 
-    *s = vaddvq_f32(sumv);
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #elif defined(__wasm_simd128__)
     v128_t sumv = wasm_f32x4_splat(0.0f);
 
+    uint32_t qh;
     uint64_t tmp[4];
 
+    // TODO: check if unrolling this is better
     for (int i = 0; i < nb; ++i) {
         const block_q5_0 * restrict x0 = &x[i];
         const block_q8_0 * restrict y0 = &y[i];
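
Two things change in the NEON path above: the loop now processes two blocks per iteration into independent accumulators (sumv0/sumv1), which helps hide instruction latency, and the old OR-then-subtract is folded into a single vsubq_s8 via table_b2b_1. The fold rests on the identity (q | (h << 4)) - 16 == q - ((!h) << 4) for a nibble q in [0, 15] and high bit h in {0, 1}; note the i += 2 step also assumes nb is even. A quick exhaustive check of the identity (a sketch, not part of the commit):

#include <assert.h>

int main(void) {
    for (int q = 0; q < 16; ++q) {
        for (int h = 0; h < 2; ++h) {
            const int old_way = (q | (h << 4)) - 16; // OR in the 5th bit, then sub 16
            const int new_way = q - ((!h) << 4);     // single subtract via table_b2b_1
            assert(old_way == new_way);
        }
    }
    return 0;
}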
@@ -2377,13 +2412,12 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const v128_t s16b = wasm_i8x16_splat(0x10);
 
         // extract the 5th bit
-        uint32_t qh;
         memcpy(&qh, x0->qh, sizeof(qh));
 
-        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_u[(qh >> 24)       ];
+        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh >> 24)       ];
 
         const v128_t qhl = wasm_v128_load(tmp + 0);
         const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -2395,8 +2429,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const v128_t v0h = wasm_u8x16_shr(v0, 4);
 
         // add high bit and sub 16
-        const v128_t v0lf = wasm_i8x16_sub(wasm_v128_or(v0l, qhl), s16b);
-        const v128_t v0hf = wasm_i8x16_sub(wasm_v128_or(v0h, qhh), s16b);
+        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
 
         // load y
         const v128_t v1l = wasm_v128_load(y0->qs);
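
The wasm path picks up the same table_b2b_1 subtraction but stays un-unrolled for now (see the TODO above). For orientation, a plain-C sketch of the per-block integer dot product every Q5_0 path computes, assuming the layout the SIMD code implies (byte qs[j] packs the low nibbles of elements j and j + 16; bit j of qh is the 5th bit of element j):

#include <stdint.h>
#include <string.h>

// reference dot of one Q5_0 block against one Q8_0 block; the caller scales
// the result by GGML_FP16_TO_FP32(x->d) * y->d and accumulates
static int q5_0_q8_0_block_dot_ref(const uint8_t qs[16], const uint8_t qh_bytes[4],
                                   const int8_t yqs[32]) {
    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));
    int sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int xl = (qs[j] & 0x0F) | (((qh >> j)        & 1) << 4); // element j
        const int xh = (qs[j] >>   4) | (((qh >> (j + 16)) & 1) << 4); // element j + 16
        sumi += (xl - 16) * yqs[j] + (xh - 16) * yqs[j + 16];
    }
    return sumi;
}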
@@ -2488,69 +2522,107 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
     const block_q8_1 * restrict y = vy;
 
 #if defined(__ARM_NEON)
-    float32x4_t sumv = vdupq_n_f32(0.0f);
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    float summs = 0.0f;
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
 
-    uint64_t tmp[4];
+    uint32_t qh0;
+    uint32_t qh1;
 
-    for (int i = 0; i < nb; ++i) {
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    for (int i = 0; i < nb; i += 2) {
         const block_q5_1 * restrict x0 = &x[i];
+        const block_q5_1 * restrict x1 = &x[i + 1];
         const block_q8_1 * restrict y0 = &y[i];
+        const block_q8_1 * restrict y1 = &y[i + 1];
 
-        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
-        // extract the 5th bit
-        uint32_t qh;
-        memcpy(&qh, x0->qh, sizeof(qh));
+        summs0 += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+        summs1 += GGML_FP16_TO_FP32(x1->m) * (y1->s0 + y1->s1);
+
+        // extract the 5th bit via lookup table ((b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
 
-        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_u[(qh >> 24)       ];
+        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];
 
-        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
-        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];
 
-        const uint8x16_t v0 = vld1q_u8(x0->qs);
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
 
         // 4-bit -> 8-bit
-        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, vdupq_n_u8(0x0F)));
-        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
-        // add
-        const int8x16_t v0lf = vorrq_s8(v0l, qhl);
-        const int8x16_t v0hf = vorrq_s8(v0h, qhh);
+        // add 5th bit
+        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
 
         // load y
-        const int8x16_t v1l = vld1q_s8(y0->qs);
-        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
         const float x0d = GGML_FP16_TO_FP32(x0->d);
+        const float x1d = GGML_FP16_TO_FP32(x1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
-                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
-                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
 #endif
     }
 
-    *s = vaddvq_f32(sumv) + summs;
+    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
 #elif defined(__wasm_simd128__)
     v128_t sumv = wasm_f32x4_splat(0.0f);
 
     float summs = 0.0f;
 
+    uint32_t qh;
     uint64_t tmp[4];
 
     for (int i = 0; i < nb; ++i) {
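
Q5_1 gets the same two-block unroll, but keeps the plain vorrq_s8: a Q5_1 block dequantizes to d * q + m with q in [0, 31], so the 5th bit is simply OR'd in via table_b2b_0 and there is no offset of 16 to fold away; the m-dependent term is accumulated separately in summs0/summs1. A scalar sketch of the nibble reconstruction the NEON code vectorizes (hypothetical helper; layout per the surrounding code):

#include <stdint.h>
#include <string.h>

// expand one Q5_1 block's 32 quants to their full 5-bit values in [0, 31]
static void q5_1_expand_ref(const uint8_t qs[16], const uint8_t qh_bytes[4],
                            uint8_t out[32]) {
    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));
    for (int j = 0; j < 16; ++j) {
        out[j]      = (uint8_t)((qs[j] & 0x0F) | (((qh >> j)        & 1) << 4));
        out[j + 16] = (uint8_t)((qs[j] >>   4) | (((qh >> (j + 16)) & 1) << 4));
    }
}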
@@ -2562,13 +2634,12 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const v128_t m4b = wasm_i8x16_splat(0x0F);
 
         // extract the 5th bit
-        uint32_t qh;
         memcpy(&qh, x0->qh, sizeof(qh));
 
-        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_u[(qh >> 24)       ];
+        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh >> 24)       ];
 
         const v128_t qhl = wasm_v128_load(tmp + 0);
         const v128_t qhh = wasm_v128_load(tmp + 2);
