Skip to content

Commit 6496b79

Browse files
committed
ggml : use q4_0_q8_0 and q4_2_q8_0
1 parent d8bf720 commit 6496b79

File tree

1 file changed

+24
-21
lines changed

1 file changed

+24
-21
lines changed

ggml.c

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,9 +1825,9 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
18251825
}
18261826
}
18271827

1828-
static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1828+
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
18291829
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1830-
static void ggml_vec_dot_q4_2_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1830+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
18311831
static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
18321832
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
18331833

@@ -1837,7 +1837,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
18371837
.quantize_row_q = quantize_row_q4_0,
18381838
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
18391839
.quantize_row_q_dot = quantize_row_q8_1,
1840-
.vec_dot_q = ggml_vec_dot_q4_0_q8_1,
1840+
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
18411841
},
18421842
[GGML_TYPE_Q4_1] = {
18431843
.dequantize_row_q = dequantize_row_q4_1,
@@ -1851,7 +1851,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
18511851
.quantize_row_q = quantize_row_q4_2,
18521852
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
18531853
.quantize_row_q_dot = quantize_row_q8_1,
1854-
.vec_dot_q = ggml_vec_dot_q4_2_q8_1,
1854+
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
18551855
},
18561856
[GGML_TYPE_Q4_3] = {
18571857
.dequantize_row_q = dequantize_row_q4_3,
@@ -2475,7 +2475,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
24752475
*s = sumf;
24762476
}
24772477

2478-
static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2478+
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
24792479
const int nb = n / QK8_1;
24802480

24812481
assert(n % QK8_1 == 0);
@@ -2488,17 +2488,14 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
24882488
float32x4_t sumv0 = vdupq_n_f32(0.0f);
24892489
float32x4_t sumv1 = vdupq_n_f32(0.0f);
24902490

2491-
float sum8 = 0;
2492-
24932491
for (int i = 0; i < nb; i += 2) {
24942492
const block_q4_0 * restrict x0 = &x[i + 0];
24952493
const block_q4_0 * restrict x1 = &x[i + 1];
24962494
const block_q8_1 * restrict y0 = &y[i + 0];
24972495
const block_q8_1 * restrict y1 = &y[i + 1];
24982496

2499-
sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1);
2500-
25012497
const uint8x16_t m4b = vdupq_n_u8(0xf);
2498+
const int8x16_t s8b = vdupq_n_s8(0x8);
25022499

25032500
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
25042501
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2509,6 +2506,12 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
25092506
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
25102507
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
25112508

2509+
// sub 8
2510+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
2511+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
2512+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
2513+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
2514+
25122515
// load y
25132516
const int8x16_t v1_0l = vld1q_s8(y0->qs);
25142517
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2523,21 +2526,21 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
25232526

25242527
#if defined(__ARM_FEATURE_DOTPROD)
25252528
// dot product into int32x4_t
2526-
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2527-
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
2529+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs);
2530+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs);
25282531

25292532
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
25302533
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
25312534
#else
2532-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
2533-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
2534-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
2535-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
2535+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
2536+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
2537+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
2538+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
25362539

2537-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
2538-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
2539-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
2540-
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
2540+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
2541+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
2542+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
2543+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
25412544

25422545
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
25432546
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2549,7 +2552,7 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
25492552
#endif
25502553
}
25512554

2552-
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8;
2555+
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
25532556
#elif defined(__AVX2__)
25542557
// Initialize accumulator with zeros
25552558
__m256 acc = _mm256_setzero_ps();
@@ -2775,7 +2778,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
27752778
#endif
27762779
}
27772780

2778-
static void ggml_vec_dot_q4_2_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2781+
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
27792782
const int nb = n / QK8_1;
27802783

27812784
assert(n % QK8_1 == 0);

0 commit comments

Comments
 (0)