@@ -1359,8 +1359,8 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
1359
1359
const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
1360
1360
const int8x16_t v1_1hs = vsubq_s8 (v1_1h , s8b );
1361
1361
1362
+ #if defined(__ARM_FEATURE_DOTPROD )
1362
1363
// dot product into int16x8_t
1363
- // assume that vdotq_s32 is always available, if not, should check for __ARM_FEATURE_DOTPROD
1364
1364
int32x4_t p_0 = vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0ls );
1365
1365
int32x4_t p_1 = vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1ls );
1366
1366
@@ -1374,6 +1374,37 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
1374
1374
#else
1375
1375
sum0 += d0_0 * d1_0 * (vgetq_lane_s32 (p_0 , 0 ) + vgetq_lane_s32 (p_0 , 1 ) + vgetq_lane_s32 (p_0 , 2 ) + vgetq_lane_s32 (p_0 , 3 ));
1376
1376
sum1 += d0_1 * d1_1 * (vgetq_lane_s32 (p_1 , 0 ) + vgetq_lane_s32 (p_1 , 1 ) + vgetq_lane_s32 (p_1 , 2 ) + vgetq_lane_s32 (p_1 , 3 ));
1377
+ #endif
1378
+ #else
1379
+ const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0ls ));
1380
+ const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0ls ));
1381
+
1382
+ const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0hs ));
1383
+ const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0hs ));
1384
+
1385
+ const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1ls ));
1386
+ const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1ls ));
1387
+
1388
+ const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1hs ));
1389
+ const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1hs ));
1390
+
1391
+ const int16x8_t pl_0 = vaddq_s16 (pl0l , pl0h );
1392
+ const int16x8_t ph_0 = vaddq_s16 (ph0l , ph0h );
1393
+
1394
+ const int16x8_t pl_1 = vaddq_s16 (pl1l , pl1h );
1395
+ const int16x8_t ph_1 = vaddq_s16 (ph1l , ph1h );
1396
+
1397
+ const int16x8_t p_0 = vaddq_s16 (pl_0 , ph_0 );
1398
+ const int16x8_t p_1 = vaddq_s16 (pl_1 , ph_1 );
1399
+
1400
+ // scalar
1401
+ #if defined(__ARM_FEATURE_QRDMX )
1402
+ sum0 += d0_0 * d1_0 * vaddvq_s16 (p_0 );
1403
+ sum1 += d0_1 * d1_1 * vaddvq_s16 (p_1 );
1404
+ #else
1405
+ sum0 += d0_0 * d1_0 * (vgetq_lane_s16 (p_0 , 0 ) + vgetq_lane_s16 (p_0 , 1 ) + vgetq_lane_s16 (p_0 , 2 ) + vgetq_lane_s16 (p_0 , 3 ) + vgetq_lane_s16 (p_0 , 4 ) + vgetq_lane_s16 (p_0 , 5 ) + vgetq_lane_s16 (p_0 , 6 ) + vgetq_lane_s16 (p_0 , 7 ));
1406
+ sum1 += d0_1 * d1_1 * (vgetq_lane_s16 (p_1 , 0 ) + vgetq_lane_s16 (p_1 , 1 ) + vgetq_lane_s16 (p_1 , 2 ) + vgetq_lane_s16 (p_1 , 3 ) + vgetq_lane_s16 (p_1 , 4 ) + vgetq_lane_s16 (p_1 , 5 ) + vgetq_lane_s16 (p_1 , 6 ) + vgetq_lane_s16 (p_1 , 7 ));
1407
+ #endif
1377
1408
#endif
1378
1409
}
1379
1410
0 commit comments