Skip to content

Commit 76b6b26

Browse files
committed
ggml : slight improvement of Q4_3 - no need for loop unrolling
1 parent 829c480 commit 76b6b26

File tree

1 file changed

+12
-36
lines changed

1 file changed

+12
-36
lines changed

ggml.c

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2875,77 +2875,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
28752875
float32x4_t sumv0 = vdupq_n_f32(0.0f);
28762876
float32x4_t sumv1 = vdupq_n_f32(0.0f);
28772877

2878-
float summs = 0.0f;
2878+
float summs0 = 0.0f;
2879+
float summs1 = 0.0f;
28792880

2880-
for (int i = 0; i < nb; i += 2) {
2881+
for (int i = 0; i < nb; ++i) {
28812882
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
28822883
const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2883-
const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2884-
const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
28852884

28862885
const block_q8_0 * restrict y0 = &y[i + 0];
2887-
const block_q8_0 * restrict y1 = &y[i + 1];
28882886

2889-
summs += GGML_FP16_TO_FP32(x0_0->m) * y0->s0 + GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
2890-
summs += GGML_FP16_TO_FP32(x1_0->m) * y1->s0 + GGML_FP16_TO_FP32(x1_1->m) * y1->s1;
2891-
2892-
const uint8x16_t m4b = vdupq_n_u8(0xf);
2893-
2894-
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2895-
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2896-
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
2897-
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2887+
summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
2888+
summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
28982889

28992890
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
2900-
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
29012891

29022892
// 4-bit -> 8-bit
2903-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2893+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf)));
29042894
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
2905-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
2906-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
29072895

29082896
// interleave
29092897
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
29102898
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
2911-
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
2912-
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
29132899

29142900
// load y
29152901
const int8x16_t v1_0l = vld1q_s8(y0->qs);
29162902
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2917-
const int8x16_t v1_1l = vld1q_s8(y1->qs);
2918-
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
2903+
2904+
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2905+
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
29192906

29202907
#if defined(__ARM_FEATURE_DOTPROD)
29212908
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
2922-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
2923-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
2924-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
2909+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
29252910
#else
29262911
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
29272912
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
29282913
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
29292914
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
29302915

2931-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2932-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2933-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2934-
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2935-
29362916
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
29372917
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
2938-
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
2939-
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
29402918

29412919
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
2942-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
2943-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
2944-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
2920+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
29452921
#endif
29462922
}
29472923

2948-
sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs;
2924+
*s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
29492925
#elif defined(__AVX2__)
29502926
// Initialize accumulator with zeros
29512927
__m256 acc = _mm256_setzero_ps();

0 commit comments

Comments
 (0)