Skip to content

Commit 6e0f0b6

Browse files
committed
ggml : Q8_0 unroll x2
1 parent 88618ab commit 6e0f0b6

File tree

1 file changed

+36
-18
lines changed

1 file changed

+36
-18
lines changed

ggml.c

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3079,32 +3079,50 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
30793079
float32x4_t sumv0 = vdupq_n_f32(0.0f);
30803080
float32x4_t sumv1 = vdupq_n_f32(0.0f);
30813081

3082-
for (int i = 0; i < nb; ++i) {
3083-
const block_q8_0 * restrict x0 = &x[i];
3084-
const block_q8_0 * restrict y0 = &y[i];
3082+
for (int i = 0; i < nb; i += 2) {
3083+
const block_q8_0 * restrict x0 = &x[i + 0];
3084+
const block_q8_0 * restrict x1 = &x[i + 1];
3085+
const block_q8_0 * restrict y0 = &y[i + 0];
3086+
const block_q8_0 * restrict y1 = &y[i + 1];
30853087

3086-
const int8x16_t v0_0 = vld1q_s8(x0->qs);
3087-
const int8x16_t v0_1 = vld1q_s8(x0->qs + 16);
3088+
const int8x16_t x0_0 = vld1q_s8(x0->qs);
3089+
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
3090+
const int8x16_t x1_0 = vld1q_s8(x1->qs);
3091+
const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
30883092

30893093
// load y
3090-
const int8x16_t v1_0 = vld1q_s8(y0->qs);
3091-
const int8x16_t v1_1 = vld1q_s8(y0->qs + 16);
3094+
const int8x16_t y0_0 = vld1q_s8(y0->qs);
3095+
const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
3096+
const int8x16_t y1_0 = vld1q_s8(y1->qs);
3097+
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
30923098

30933099
#if defined(__ARM_FEATURE_DOTPROD)
30943100
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
3095-
vdotq_s32(vdupq_n_s32(0), v0_0, v1_0),
3096-
vdotq_s32(vdupq_n_s32(0), v0_1, v1_1))), x0->d*y0->d);
3097-
#else
3098-
const int16x8_t p0l = vmull_s8(vget_low_s8 (v0_0), vget_low_s8 (v1_0));
3099-
const int16x8_t p0h = vmull_s8(vget_high_s8(v0_0), vget_high_s8(v1_0));
3100-
const int16x8_t p1l = vmull_s8(vget_low_s8 (v0_1), vget_low_s8 (v1_1));
3101-
const int16x8_t p1h = vmull_s8(vget_high_s8(v0_1), vget_high_s8(v1_1));
3101+
vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
3102+
vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
31023103

3103-
const int32x4_t pl = vaddq_s32(vpaddlq_s16(p0l), vpaddlq_s16(p0h));
3104-
const int32x4_t ph = vaddq_s32(vpaddlq_s16(p1l), vpaddlq_s16(p1h));
3104+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
3105+
vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
3106+
vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
31053107

3106-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl), x0->d*y0->d);
3107-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph), x0->d*y0->d);
3108+
#else
3109+
const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
3110+
const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
3111+
const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1));
3112+
const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
3113+
3114+
const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0));
3115+
const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
3116+
const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1));
3117+
const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
3118+
3119+
const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
3120+
const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
3121+
const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
3122+
const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
3123+
3124+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
3125+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
31083126
#endif
31093127
}
31103128

0 commit comments

Comments
 (0)