Skip to content

Commit 47857e5

Browse files
Ronsorggerganov
andauthored
Don't use vdotq_s32 if it's not available (abetlen#139)
* Don't use vdotq_s32 if it's not available `dotprod` extensions aren't available on some ARM CPUs (e.g. Raspberry Pi 4), so check for them and only use them if they're available. Reintroduces the code removed in 84d9015 if `__ARM_FEATURE_DOTPROD` isn't defined. * Update ggml.c --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 60f819a commit 47857e5

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

ggml.c

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1359,8 +1359,8 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
13591359
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
13601360
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
13611361

1362+
#if defined(__ARM_FEATURE_DOTPROD)
13621363
// dot product into int16x8_t
1363-
// assume that vdotq_s32 is always available, if not, should check for __ARM_FEATURE_DOTPROD
13641364
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
13651365
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
13661366

@@ -1374,6 +1374,37 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
13741374
#else
13751375
sum0 += d0_0*d1_0*(vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
13761376
sum1 += d0_1*d1_1*(vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
1377+
#endif
1378+
#else
1379+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
1380+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
1381+
1382+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
1383+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
1384+
1385+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
1386+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
1387+
1388+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
1389+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
1390+
1391+
const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
1392+
const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
1393+
1394+
const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
1395+
const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
1396+
1397+
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
1398+
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
1399+
1400+
// scalar
1401+
#if defined(__ARM_FEATURE_QRDMX)
1402+
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
1403+
sum1 += d0_1*d1_1*vaddvq_s16(p_1);
1404+
#else
1405+
sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
1406+
sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
1407+
#endif
13771408
#endif
13781409
}
13791410

0 commit comments

Comments
 (0)