Skip to content

Commit 2ae3164

Browse files
committed
ggml : speed-up q4_1 ARM_NEON by ~5%
1 parent 9190e8e commit 2ae3164

File tree

1 file changed

+78
-43
lines changed

1 file changed

+78
-43
lines changed

ggml.c

Lines changed: 78 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,32 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
491491
}
492492
#endif
493493

494+
#if __ARM_NEON
495+
#if !defined(__ARM_FEATURE_QRDMX)
496+
497+
inline static int16_t vaddvq_s16(int16x8_t v) {
498+
const int16x4_t v1 = vadd_s16(vget_low_s16(v), vget_high_s16(v));
499+
return vaddv_s16(v1);
500+
}
501+
502+
inline static uint16_t vaddvq_u16(uint16x8_t v) {
503+
const uint16x4_t v1 = vadd_u16(vget_low_u16(v), vget_high_u16(v));
504+
return vaddv_u16(v1);
505+
}
506+
507+
inline static int32_t vaddvq_s32(int32x4_t v) {
508+
const int32x2_t v1 = vadd_s32(vget_low_s32(v), vget_high_s32(v));
509+
return vaddv_s32(v1);
510+
}
511+
512+
inline static float vaddvq_f32(float32x4_t v) {
513+
const float32x2_t v1 = vadd_f32(vget_low_f32(v), vget_high_f32(v));
514+
return vaddv_f32(v1);
515+
}
516+
517+
#endif
518+
#endif
519+
494520
// method 5
495521
// blocks of QK elements
496522
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1218,15 +1244,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
12181244
#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
12191245
#define GGML_F32x4_ADD vaddq_f32
12201246
#define GGML_F32x4_MUL vmulq_f32
1221-
#if defined(__ARM_FEATURE_QRDMX)
1222-
#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
1223-
#else
1224-
#define GGML_F32x4_REDUCE_ONE(x) \
1225-
(vgetq_lane_f32(x, 0) + \
1226-
vgetq_lane_f32(x, 1) + \
1227-
vgetq_lane_f32(x, 2) + \
1228-
vgetq_lane_f32(x, 3))
1229-
#endif
1247+
#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
12301248
#define GGML_F32x4_REDUCE(res, x) \
12311249
{ \
12321250
for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1849,55 +1867,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
18491867
// 4-bit -> 8-bit
18501868
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
18511869
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
1852-
18531870
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
18541871
const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
18551872

18561873
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
18571874
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
1858-
18591875
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
18601876
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
18611877

18621878
// sub 8
18631879
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
18641880
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
1865-
18661881
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
18671882
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
18681883

18691884
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
18701885
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
1871-
18721886
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
18731887
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
18741888

18751889
#if defined(__ARM_FEATURE_DOTPROD)
1876-
// dot product into int16x8_t
1890+
// dot product into int32x4_t
18771891
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
18781892
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
18791893

18801894
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
18811895
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
18821896

1883-
// scalar
1884-
#if defined(__ARM_FEATURE_QRDMX)
1885-
sum0 += x0->d * y0->d * vaddvq_s32(p_0);
1886-
sum1 += x1->d * y1->d * vaddvq_s32(p_1);
1887-
#else
1888-
sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
1889-
sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
1890-
#endif
1897+
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
1898+
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
18911899
#else
18921900
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
18931901
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1894-
18951902
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
18961903
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
18971904

18981905
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
18991906
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1900-
19011907
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
19021908
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
19031909

@@ -1910,14 +1916,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19101916
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
19111917
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
19121918

1913-
// scalar
1914-
#if defined(__ARM_FEATURE_QRDMX)
1915-
sum0 += x0->d * y0->d * vaddvq_s16(p_0);
1916-
sum1 += x1->d * y1->d * vaddvq_s16(p_1);
1917-
#else
1918-
sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
1919-
sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
1920-
#endif
1919+
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
1920+
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
19211921
#endif
19221922
}
19231923

@@ -2265,36 +2265,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
22652265
float sum10 = 0.0f;
22662266
float sum11 = 0.0f;
22672267

2268-
for (int i = 0; i < nb; ++i) {
2268+
for (int i = 0; i < nb; i += 2) {
22692269
const block_q4_1 * restrict x0 = &x[i + 0];
22702270
const block_q4_1 * restrict y0 = &y[i + 0];
2271+
const block_q4_1 * restrict x1 = &x[i + 1];
2272+
const block_q4_1 * restrict y1 = &y[i + 1];
22712273

22722274
const uint8x16_t m4b = vdupq_n_u8(0xf);
22732275

22742276
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
22752277
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
2278+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
2279+
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
22762280

2277-
// and with 0xf
2281+
// 4-bit -> 8-bit
22782282
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
22792283
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
2280-
22812284
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
22822285
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
22832286

2284-
// dot product into uint16x8_t
2287+
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
2288+
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
2289+
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
2290+
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
2291+
2292+
sum00 += x0->m*y0->m;
2293+
sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2294+
sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2295+
2296+
sum00 += x1->m*y1->m;
2297+
sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
2298+
sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
2299+
2300+
#if defined(__ARM_FEATURE_DOTPROD)
2301+
// dot product into int32x4_t
2302+
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
2303+
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
2304+
2305+
p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
2306+
p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
2307+
2308+
sum11 += x0->d*y0->d*vaddvq_s32(p_0);
2309+
sum11 += x1->d*y1->d*vaddvq_s32(p_1);
2310+
#else
22852311
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
22862312
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
2287-
22882313
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
22892314
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
22902315

2291-
const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
2292-
const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
2316+
const uint16x8_t pl1l = vmull_u8(vget_low_s8 (v0_1l), vget_low_u8 (v1_1l));
2317+
const uint16x8_t pl1h = vmull_u8(vget_high_s8(v0_1l), vget_high_u8(v1_1l));
2318+
const uint16x8_t ph1l = vmull_u8(vget_low_s8 (v0_1h), vget_low_u8 (v1_1h));
2319+
const uint16x8_t ph1h = vmull_u8(vget_high_s8(v0_1h), vget_high_u8(v1_1h));
22932320

2294-
sum00 += x0->m*y0->m;
2295-
sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
2296-
sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
2297-
sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
2321+
const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
2322+
const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
2323+
2324+
const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
2325+
const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
2326+
2327+
const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
2328+
const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
2329+
2330+
sum11 += x0->d*y0->d*vaddvq_u16(p_0);
2331+
sum11 += x1->d*y1->d*vaddvq_u16(p_1);
2332+
#endif
22982333
}
22992334

23002335
sumf = QK*sum00 + sum01 + sum10 + sum11;

0 commit comments

Comments
 (0)