Skip to content

Commit d8bf720

Browse files
committed
ggml : finalize Q8_0 implementation
1 parent 79cfdf5 commit d8bf720

File tree

1 file changed

+19
-31
lines changed

1 file changed

+19
-31
lines changed

ggml.c

Lines changed: 19 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -1829,7 +1829,7 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void *
18291829
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
18301830
static void ggml_vec_dot_q4_2_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
18311831
static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1832-
static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
1832+
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
18331833

18341834
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
18351835
[GGML_TYPE_Q4_0] = {
@@ -1864,8 +1864,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
18641864
.dequantize_row_q = dequantize_row_q8_0,
18651865
.quantize_row_q = quantize_row_q8_0,
18661866
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
1867-
.quantize_row_q_dot = quantize_row_q8_1,
1868-
.vec_dot_q = ggml_vec_dot_q8_0_q8_1,
1867+
.quantize_row_q_dot = quantize_row_q8_0,
1868+
.vec_dot_q = ggml_vec_dot_q8_0_q8_0,
18691869
},
18701870
[GGML_TYPE_Q8_1] = {
18711871
.dequantize_row_q = NULL, // TODO
@@ -3062,23 +3062,23 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
30623062
#endif
30633063
}
30643064

3065-
static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3066-
const int nb = n / QK8_1;
3065+
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
3066+
const int nb = n / QK8_0;
30673067

3068-
assert(n % QK8_1 == 0);
3068+
assert(n % QK8_0 == 0);
30693069
assert(nb % 2 == 0);
3070-
assert(QK8_1 == QK8_0);
3070+
assert(QK8_0 == QK8_0);
30713071

30723072
const block_q8_0 * restrict x = vx;
3073-
const block_q8_1 * restrict y = vy;
3073+
const block_q8_0 * restrict y = vy;
30743074

3075-
#if defined(__ARM_NEON_XXX)
3075+
#if defined(__ARM_NEON)
30763076
float32x4_t sumv0 = vdupq_n_f32(0.0f);
30773077
float32x4_t sumv1 = vdupq_n_f32(0.0f);
30783078

30793079
for (int i = 0; i < nb; ++i) {
30803080
const block_q8_0 * restrict x0 = &x[i];
3081-
const block_q8_1 * restrict y0 = &y[i];
3081+
const block_q8_0 * restrict y0 = &y[i];
30823082

30833083
const int8x16_t v0_0 = vld1q_s8(x0->qs);
30843084
const int8x16_t v0_1 = vld1q_s8(x0->qs + 16);
@@ -3096,28 +3096,16 @@ static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void *
30963096
vdotq_s32(vdupq_n_s32(0), v0_0, v1_1),
30973097
vdotq_s32(vdupq_n_s32(0), v0_1, v1_0))), x0->d*y0->d);
30983098
#else
3099-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
3100-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
3101-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
3102-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
3099+
const int16x8_t p0l = vmull_s8(vget_low_s8 (v0_0), vget_low_s8 (v1_0));
3100+
const int16x8_t p0h = vmull_s8(vget_high_s8(v0_0), vget_high_s8(v1_0));
3101+
const int16x8_t p1l = vmull_s8(vget_low_s8 (v0_1), vget_low_s8 (v1_1));
3102+
const int16x8_t p1h = vmull_s8(vget_high_s8(v0_1), vget_high_s8(v1_1));
31033103

3104-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3105-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3106-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3107-
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3104+
const int32x4_t pl = vaddq_s32(vpaddlq_s16(p0l), vpaddlq_s16(p0h));
3105+
const int32x4_t ph = vaddq_s32(vpaddlq_s16(p1l), vpaddlq_s16(p1h));
31083106

3109-
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3110-
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3111-
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3112-
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3113-
3114-
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
3115-
vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)),
3116-
vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
3117-
3118-
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
3119-
vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)),
3120-
vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
3107+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl), x0->d*y0->d);
3108+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph), x0->d*y0->d);
31213109
#endif
31223110
}
31233111

@@ -3132,7 +3120,7 @@ static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void *
31323120

31333121
int sumi = 0;
31343122

3135-
for (int j = 0; j < QK8_1; j++) {
3123+
for (int j = 0; j < QK8_0; j++) {
31363124
const int v0 = x0[j];
31373125
const int v1 = y0[j];
31383126

0 commit comments

Comments (0)