Skip to content

Commit 1b49d26

Browse files
committed
q4_0c: Arm Neon acceleration
Mostly copied from the q4_0 implementation
1 parent ab543dc commit 1b49d26

File tree

1 file changed

+94
-2
lines changed

1 file changed

+94
-2
lines changed

ggml.c

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,7 +1758,37 @@ static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int
17581758
int8_t * restrict qs = vy;
17591759
float * restrict ds = (float *) ((uint8_t *) vy + nb*QK8_0C);
17601760

1761-
#if __AVX512F__
1761+
#if defined(__ARM_NEON)
1762+
for (int i = 0; i < nb; i++) {
1763+
float32x4_t srcv [8];
1764+
float32x4_t asrcv[8];
1765+
float32x4_t amaxv[8];
1766+
1767+
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
1768+
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
1769+
1770+
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
1771+
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
1772+
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
1773+
1774+
const float amax = vmaxvq_f32(amaxv[0]);
1775+
1776+
const float d = amax / ((1 << 7) - 1);
1777+
const float id = d ? 1.0f/d : 0.0f;
1778+
1779+
ds[i] = d;
1780+
1781+
for (int l = 0; l < 8; l++) {
1782+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
1783+
const int32x4_t vi = vcvtnq_s32_f32(v);
1784+
1785+
qs[i*QK8_0C + 4*l + 0] = vgetq_lane_s32(vi, 0);
1786+
qs[i*QK8_0C + 4*l + 1] = vgetq_lane_s32(vi, 1);
1787+
qs[i*QK8_0C + 4*l + 2] = vgetq_lane_s32(vi, 2);
1788+
qs[i*QK8_0C + 4*l + 3] = vgetq_lane_s32(vi, 3);
1789+
}
1790+
}
1791+
#elif defined(__AVX512F__)
17621792
for (int i = 0; i < nb; i++) {
17631793
const __m512 x0 = _mm512_loadu_ps( x + i*QK8_0C );
17641794
const __m512 x1 = _mm512_loadu_ps( x + i*QK8_0C + QK8_0C/2);
@@ -3095,7 +3125,69 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
30953125

30963126
float sumf = 0.0;
30973127

3098-
#if __AVX512F__
3128+
#if defined(__ARM_NEON)
3129+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
3130+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
3131+
3132+
for (int i = 0; i < nb/2; i++) {
3133+
const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ...
3134+
const int dst1 = i + i/2*2 + 2; // 2, 3, 6, 7, 10, 11 ...
3135+
3136+
const uint8x16_t m4b = vdupq_n_u8(0xf);
3137+
const int8x16_t s8b = vdupq_n_s8(0x8);
3138+
3139+
const uint8x16_t v0_01l = vld1q_u8(&xqs[i*QK4_0]);
3140+
const uint8x16_t v0_01h = vld1q_u8(&xqs[i*QK4_0 + QK4_0/2]);
3141+
3142+
// 4-bit -> 8-bit
3143+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_01l, m4b));
3144+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vandq_u8 (v0_01h, m4b));
3145+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vshrq_n_u8(v0_01l, 4));
3146+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_01h, 4));
3147+
3148+
// sub 8
3149+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
3150+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
3151+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
3152+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
3153+
3154+
// load y
3155+
const int8x16_t v1_0l = vld1q_s8(&yqs[dst0*QK8_0C]);
3156+
const int8x16_t v1_0h = vld1q_s8(&yqs[dst0*QK8_0C + 16]);
3157+
const int8x16_t v1_1l = vld1q_s8(&yqs[dst1*QK8_0C]);
3158+
const int8x16_t v1_1h = vld1q_s8(&yqs[dst1*QK8_0C + 16]);
3159+
3160+
#if defined(__ARM_FEATURE_DOTPROD)
3161+
// dot product into int32x4_t
3162+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
3163+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
3164+
3165+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), xds[dst0]*yds[dst0]);
3166+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), xds[dst1]*yds[dst1]);
3167+
#else
3168+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
3169+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
3170+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
3171+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
3172+
3173+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
3174+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
3175+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
3176+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
3177+
3178+
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
3179+
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3180+
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3181+
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
3182+
3183+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), xds[dst0]*yds[dst0]);
3184+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), xds[dst1]*yds[dst1]);
3185+
#endif
3186+
}
3187+
3188+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
3189+
3190+
#elif defined(__AVX512F__)
30993191
// Initialize accumulator with zeros
31003192
__m512 acc = _mm512_setzero_ps();
31013193
for (int i = 0; i < nb; i += 4) {

0 commit comments

Comments
 (0)