@@ -491,6 +491,32 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
491
491
}
492
492
#endif
493
493
494
#if __ARM_NEON

// Compatibility shims for the horizontal-add intrinsics used below.
//
// vaddvq_s16 / vaddvq_u16 / vaddvq_s32 / vaddvq_f32 (and the vaddv_* 64-bit
// forms) are AArch64-only: 32-bit ARM NEON does not provide them. Two fixes
// relative to the previous guard:
//
//  1. Guard on !defined(__aarch64__), not !defined(__ARM_FEATURE_QRDMX).
//     __ARM_FEATURE_QRDMX signals SQRDMLAH/SQRDMLSH (ARMv8.1) support and
//     says nothing about vaddvq_* availability; on plain ARMv8.0 AArch64 it
//     is undefined, which would make these shims collide with the real
//     intrinsics declared in <arm_neon.h> (redefinition error).
//
//  2. Implement the fallbacks with vgetq_lane_* only. The previous bodies
//     called vaddv_s16 / vaddv_u16 / vaddv_s32 / vaddv_f32, which are
//     themselves AArch64-only — so the shims could never compile on the one
//     target (32-bit ARM) that actually needs them.
//
// Signatures match the real AArch64 intrinsics so call sites are unchanged.
#if !defined(__aarch64__)

// Horizontal sum of the 8 signed 16-bit lanes (wraps modulo 2^16, like the
// real intrinsic: lane values are summed as int, then truncated by the cast).
inline static int16_t vaddvq_s16(int16x8_t v) {
    return (int16_t)(vgetq_lane_s16(v, 0) + vgetq_lane_s16(v, 1) +
                     vgetq_lane_s16(v, 2) + vgetq_lane_s16(v, 3) +
                     vgetq_lane_s16(v, 4) + vgetq_lane_s16(v, 5) +
                     vgetq_lane_s16(v, 6) + vgetq_lane_s16(v, 7));
}

// Horizontal sum of the 8 unsigned 16-bit lanes (wraps modulo 2^16).
inline static uint16_t vaddvq_u16(uint16x8_t v) {
    return (uint16_t)(vgetq_lane_u16(v, 0) + vgetq_lane_u16(v, 1) +
                      vgetq_lane_u16(v, 2) + vgetq_lane_u16(v, 3) +
                      vgetq_lane_u16(v, 4) + vgetq_lane_u16(v, 5) +
                      vgetq_lane_u16(v, 6) + vgetq_lane_u16(v, 7));
}

// Horizontal sum of the 4 signed 32-bit lanes.
inline static int32_t vaddvq_s32(int32x4_t v) {
    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) +
           vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}

// Horizontal sum of the 4 float lanes.
inline static float vaddvq_f32(float32x4_t v) {
    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) +
           vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
}

#endif
#endif
494
520
// method 5
495
521
// blocks of QK elements
496
522
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1218,15 +1244,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
1218
1244
#define GGML_F32x4_FMA (a , b , c ) vfmaq_f32(a, b, c)
1219
1245
#define GGML_F32x4_ADD vaddq_f32
1220
1246
#define GGML_F32x4_MUL vmulq_f32
1221
- #if defined(__ARM_FEATURE_QRDMX )
1222
- #define GGML_F32x4_REDUCE_ONE (x ) vaddvq_f32(x)
1223
- #else
1224
- #define GGML_F32x4_REDUCE_ONE (x ) \
1225
- (vgetq_lane_f32(x, 0) + \
1226
- vgetq_lane_f32(x, 1) + \
1227
- vgetq_lane_f32(x, 2) + \
1228
- vgetq_lane_f32(x, 3))
1229
- #endif
1247
+ #define GGML_F32x4_REDUCE_ONE (x ) vaddvq_f32(x)
1230
1248
#define GGML_F32x4_REDUCE (res , x ) \
1231
1249
{ \
1232
1250
for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1849,55 +1867,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1849
1867
// 4-bit -> 8-bit
1850
1868
const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_0 , m4b ));
1851
1869
const int8x16_t v1_0l = vreinterpretq_s8_u8 (vandq_u8 (v1_0 , m4b ));
1852
-
1853
1870
const int8x16_t v0_0h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_0 , 4 ));
1854
1871
const int8x16_t v1_0h = vreinterpretq_s8_u8 (vshrq_n_u8 (v1_0 , 4 ));
1855
1872
1856
1873
const int8x16_t v0_1l = vreinterpretq_s8_u8 (vandq_u8 (v0_1 , m4b ));
1857
1874
const int8x16_t v1_1l = vreinterpretq_s8_u8 (vandq_u8 (v1_1 , m4b ));
1858
-
1859
1875
const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_1 , 4 ));
1860
1876
const int8x16_t v1_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v1_1 , 4 ));
1861
1877
1862
1878
// sub 8
1863
1879
const int8x16_t v0_0ls = vsubq_s8 (v0_0l , s8b );
1864
1880
const int8x16_t v1_0ls = vsubq_s8 (v1_0l , s8b );
1865
-
1866
1881
const int8x16_t v0_0hs = vsubq_s8 (v0_0h , s8b );
1867
1882
const int8x16_t v1_0hs = vsubq_s8 (v1_0h , s8b );
1868
1883
1869
1884
const int8x16_t v0_1ls = vsubq_s8 (v0_1l , s8b );
1870
1885
const int8x16_t v1_1ls = vsubq_s8 (v1_1l , s8b );
1871
-
1872
1886
const int8x16_t v0_1hs = vsubq_s8 (v0_1h , s8b );
1873
1887
const int8x16_t v1_1hs = vsubq_s8 (v1_1h , s8b );
1874
1888
1875
1889
#if defined(__ARM_FEATURE_DOTPROD )
1876
- // dot product into int16x8_t
1890
+ // dot product into int32x4_t
1877
1891
int32x4_t p_0 = vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0ls );
1878
1892
int32x4_t p_1 = vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1ls );
1879
1893
1880
1894
p_0 = vdotq_s32 (p_0 , v0_0hs , v1_0hs );
1881
1895
p_1 = vdotq_s32 (p_1 , v0_1hs , v1_1hs );
1882
1896
1883
- // scalar
1884
- #if defined(__ARM_FEATURE_QRDMX )
1885
- sum0 += x0 -> d * y0 -> d * vaddvq_s32 (p_0 );
1886
- sum1 += x1 -> d * y1 -> d * vaddvq_s32 (p_1 );
1887
- #else
1888
- sum0 += x0 -> d * y0 -> d * (vgetq_lane_s32 (p_0 , 0 ) + vgetq_lane_s32 (p_0 , 1 ) + vgetq_lane_s32 (p_0 , 2 ) + vgetq_lane_s32 (p_0 , 3 ));
1889
- sum1 += x1 -> d * y1 -> d * (vgetq_lane_s32 (p_1 , 0 ) + vgetq_lane_s32 (p_1 , 1 ) + vgetq_lane_s32 (p_1 , 2 ) + vgetq_lane_s32 (p_1 , 3 ));
1890
- #endif
1897
+ sum0 += x0 -> d * y0 -> d * vaddvq_s32 (p_0 );
1898
+ sum1 += x1 -> d * y1 -> d * vaddvq_s32 (p_1 );
1891
1899
#else
1892
1900
const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0ls ));
1893
1901
const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0ls ));
1894
-
1895
1902
const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0hs ));
1896
1903
const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0hs ));
1897
1904
1898
1905
const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1ls ));
1899
1906
const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1ls ));
1900
-
1901
1907
const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1hs ));
1902
1908
const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1hs ));
1903
1909
@@ -1910,14 +1916,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1910
1916
const int16x8_t p_0 = vaddq_s16 (pl_0 , ph_0 );
1911
1917
const int16x8_t p_1 = vaddq_s16 (pl_1 , ph_1 );
1912
1918
1913
- // scalar
1914
- #if defined(__ARM_FEATURE_QRDMX )
1915
- sum0 += x0 -> d * y0 -> d * vaddvq_s16 (p_0 );
1916
- sum1 += x1 -> d * y1 -> d * vaddvq_s16 (p_1 );
1917
- #else
1918
- sum0 += x0 -> d * y0 -> d * (vgetq_lane_s16 (p_0 , 0 ) + vgetq_lane_s16 (p_0 , 1 ) + vgetq_lane_s16 (p_0 , 2 ) + vgetq_lane_s16 (p_0 , 3 ) + vgetq_lane_s16 (p_0 , 4 ) + vgetq_lane_s16 (p_0 , 5 ) + vgetq_lane_s16 (p_0 , 6 ) + vgetq_lane_s16 (p_0 , 7 ));
1919
- sum1 += x1 -> d * y1 -> d * (vgetq_lane_s16 (p_1 , 0 ) + vgetq_lane_s16 (p_1 , 1 ) + vgetq_lane_s16 (p_1 , 2 ) + vgetq_lane_s16 (p_1 , 3 ) + vgetq_lane_s16 (p_1 , 4 ) + vgetq_lane_s16 (p_1 , 5 ) + vgetq_lane_s16 (p_1 , 6 ) + vgetq_lane_s16 (p_1 , 7 ));
1920
- #endif
1919
+ sum0 += x0 -> d * y0 -> d * vaddvq_s16 (p_0 );
1920
+ sum1 += x1 -> d * y1 -> d * vaddvq_s16 (p_1 );
1921
1921
#endif
1922
1922
}
1923
1923
@@ -2265,36 +2265,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
2265
2265
float sum10 = 0.0f ;
2266
2266
float sum11 = 0.0f ;
2267
2267
2268
- for (int i = 0 ; i < nb ; ++ i ) {
2268
+ for (int i = 0 ; i < nb ; i += 2 ) {
2269
2269
const block_q4_1 * restrict x0 = & x [i + 0 ];
2270
2270
const block_q4_1 * restrict y0 = & y [i + 0 ];
2271
+ const block_q4_1 * restrict x1 = & x [i + 1 ];
2272
+ const block_q4_1 * restrict y1 = & y [i + 1 ];
2271
2273
2272
2274
const uint8x16_t m4b = vdupq_n_u8 (0xf );
2273
2275
2274
2276
const uint8x16_t v0_0 = vld1q_u8 (x0 -> qs );
2275
2277
const uint8x16_t v1_0 = vld1q_u8 (y0 -> qs );
2278
+ const uint8x16_t v0_1 = vld1q_u8 (x1 -> qs );
2279
+ const uint8x16_t v1_1 = vld1q_u8 (y1 -> qs );
2276
2280
2277
- // and with 0xf
2281
+ // 4-bit -> 8-bit
2278
2282
const uint8x16_t v0_0l = vandq_u8 (v0_0 , m4b );
2279
2283
const uint8x16_t v1_0l = vandq_u8 (v1_0 , m4b );
2280
-
2281
2284
const uint8x16_t v0_0h = vshrq_n_u8 (v0_0 , 4 );
2282
2285
const uint8x16_t v1_0h = vshrq_n_u8 (v1_0 , 4 );
2283
2286
2284
- // dot product into uint16x8_t
2287
+ const uint8x16_t v0_1l = vandq_u8 (v0_1 , m4b );
2288
+ const uint8x16_t v1_1l = vandq_u8 (v1_1 , m4b );
2289
+ const uint8x16_t v0_1h = vshrq_n_u8 (v0_1 , 4 );
2290
+ const uint8x16_t v1_1h = vshrq_n_u8 (v1_1 , 4 );
2291
+
2292
+ sum00 += x0 -> m * y0 -> m ;
2293
+ sum01 += y0 -> m * x0 -> d * (vaddvq_u8 (v0_0l ) + vaddvq_u8 (v0_0h ));
2294
+ sum10 += x0 -> m * y0 -> d * (vaddvq_u8 (v1_0l ) + vaddvq_u8 (v1_0h ));
2295
+
2296
+ sum00 += x1 -> m * y1 -> m ;
2297
+ sum01 += y1 -> m * x1 -> d * (vaddvq_u8 (v0_1l ) + vaddvq_u8 (v0_1h ));
2298
+ sum10 += x1 -> m * y1 -> d * (vaddvq_u8 (v1_1l ) + vaddvq_u8 (v1_1h ));
2299
+
2300
+ #if defined(__ARM_FEATURE_DOTPROD )
2301
+ // dot product into int32x4_t
2302
+ int32x4_t p_0 = vdotq_s32 (vdupq_n_s32 (0 ), v0_0l , v1_0l );
2303
+ int32x4_t p_1 = vdotq_s32 (vdupq_n_s32 (0 ), v0_1l , v1_1l );
2304
+
2305
+ p_0 = vdotq_s32 (p_0 , v0_0h , v1_0h );
2306
+ p_1 = vdotq_s32 (p_1 , v0_1h , v1_1h );
2307
+
2308
+ sum11 += x0 -> d * y0 -> d * vaddvq_s32 (p_0 );
2309
+ sum11 += x1 -> d * y1 -> d * vaddvq_s32 (p_1 );
2310
+ #else
2285
2311
const uint16x8_t pl0l = vmull_u8 (vget_low_u8 (v0_0l ), vget_low_u8 (v1_0l ));
2286
2312
const uint16x8_t pl0h = vmull_u8 (vget_high_u8 (v0_0l ), vget_high_u8 (v1_0l ));
2287
-
2288
2313
const uint16x8_t ph0l = vmull_u8 (vget_low_u8 (v0_0h ), vget_low_u8 (v1_0h ));
2289
2314
const uint16x8_t ph0h = vmull_u8 (vget_high_u8 (v0_0h ), vget_high_u8 (v1_0h ));
2290
2315
2291
- const uint16x8_t pl0 = vaddq_u16 (pl0l , pl0h );
2292
- const uint16x8_t ph0 = vaddq_u16 (ph0l , ph0h );
2316
+ const uint16x8_t pl1l = vmull_u8 (vget_low_s8 (v0_1l ), vget_low_u8 (v1_1l ));
2317
+ const uint16x8_t pl1h = vmull_u8 (vget_high_s8 (v0_1l ), vget_high_u8 (v1_1l ));
2318
+ const uint16x8_t ph1l = vmull_u8 (vget_low_s8 (v0_1h ), vget_low_u8 (v1_1h ));
2319
+ const uint16x8_t ph1h = vmull_u8 (vget_high_s8 (v0_1h ), vget_high_u8 (v1_1h ));
2293
2320
2294
- sum00 += x0 -> m * y0 -> m ;
2295
- sum01 += y0 -> m * x0 -> d * (vaddvq_u8 (v0_0l ) + vaddvq_u8 (v0_0h ));
2296
- sum10 += x0 -> m * y0 -> d * (vaddvq_u8 (v1_0l ) + vaddvq_u8 (v1_0h ));
2297
- sum11 += x0 -> d * y0 -> d * vaddvq_u16 (vaddq_u16 (pl0 , ph0 ));
2321
+ const uint16x8_t pl_0 = vaddq_u16 (pl0l , pl0h );
2322
+ const uint16x8_t ph_0 = vaddq_u16 (ph0l , ph0h );
2323
+
2324
+ const uint16x8_t pl_1 = vaddq_u16 (pl1l , pl1h );
2325
+ const uint16x8_t ph_1 = vaddq_u16 (ph1l , ph1h );
2326
+
2327
+ const uint16x8_t p_0 = vaddq_u16 (pl_0 , ph_0 );
2328
+ const uint16x8_t p_1 = vaddq_u16 (pl_1 , ph_1 );
2329
+
2330
+ sum11 += x0 -> d * y0 -> d * vaddvq_u16 (p_0 );
2331
+ sum11 += x1 -> d * y1 -> d * vaddvq_u16 (p_1 );
2332
+ #endif
2298
2333
}
2299
2334
2300
2335
sumf = QK * sum00 + sum01 + sum10 + sum11 ;
0 commit comments