@@ -272,10 +272,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 
 // vaddvq_s16
 // vpaddq_s16
+// vpaddq_s32
 // vaddvq_s32
 // vaddvq_f32
 // vmaxvq_f32
 // vcvtnq_s32_f32
+// vzip1_u8
+// vzip2_u8
 
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
@@ -291,6 +294,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
     return vcombine_s16(a0, b0);
 }
 
+inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+    return vcombine_s32(a0, b0);
+}
+
 inline static int32_t vaddvq_s32(int32x4_t v) {
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
 }
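[Annotation, not part of the commit] The vpaddq_s32 shim above mirrors NEON's pairwise-add semantics: the low half of the result holds the pairwise sums of a, the high half those of b. A minimal scalar model of that behavior, runnable on any host (the ref_ name is hypothetical, not from ggml):

#include <stdint.h>
#include <stdio.h>

/* Scalar reference for vpaddq_s32: out = {a0+a1, a2+a3, b0+b1, b2+b3}. */
static void ref_vpaddq_s32(const int32_t a[4], const int32_t b[4], int32_t out[4]) {
    out[0] = a[0] + a[1];
    out[1] = a[2] + a[3];
    out[2] = b[0] + b[1];
    out[3] = b[2] + b[3];
}

int main(void) {
    const int32_t a[4] = {1, 2, 3, 4}, b[4] = {5, 6, 7, 8};
    int32_t r[4];
    ref_vpaddq_s32(a, b, r);
    printf("%d %d %d %d\n", r[0], r[1], r[2], r[3]); /* prints: 3 7 11 15 */
    return 0;
}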
@@ -316,6 +325,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
     return res;
 }
 
+inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[0]; res[1] = b[0];
+    res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2];
+    res[6] = a[3]; res[7] = b[3];
+
+    return res;
+}
+
+inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[4]; res[1] = b[4];
+    res[2] = a[5]; res[3] = b[5];
+    res[4] = a[6]; res[5] = b[6];
+    res[6] = a[7]; res[7] = b[7];
+
+    return res;
+}
+
 // vld1q_s16_x2
 // vld1q_u8_x2
 // vld1q_u8_x4
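[Annotation, not part of the commit] vzip1_u8 interleaves the low halves of a and b, vzip2_u8 the high halves, matching the AArch64 intrinsics these shims stand in for. A scalar model runnable anywhere (the ref_ name is hypothetical):

#include <stdint.h>
#include <stdio.h>

/* vzip1: out = {a0,b0,a1,b1,a2,b2,a3,b3}; vzip2: out = {a4,b4,...,a7,b7}. */
static void ref_vzip_u8(const uint8_t a[8], const uint8_t b[8], int hi, uint8_t out[8]) {
    const int base = hi ? 4 : 0;  /* vzip2 reads from the upper half */
    for (int i = 0; i < 4; ++i) {
        out[2*i + 0] = a[base + i];
        out[2*i + 1] = b[base + i];
    }
}

int main(void) {
    const uint8_t a[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    const uint8_t b[8] = {10, 11, 12, 13, 14, 15, 16, 17};
    uint8_t lo[8], hi[8];
    ref_vzip_u8(a, b, 0, lo);  /* {0,10,1,11,2,12,3,13} */
    ref_vzip_u8(a, b, 1, hi);  /* {4,14,5,15,6,16,7,17} */
    for (int i = 0; i < 8; ++i) printf("%d ", lo[i]);
    printf("\n");
    for (int i = 0; i < 8; ++i) printf("%d ", hi[i]);
    printf("\n");
    return 0;
}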
@@ -7554,9 +7585,9 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
     const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
 
-    int8x16x4_t q2u;
-    int8x16x4_t q2s;
-    int8x16x4_t q8b;
+    ggml_int8x16x4_t q2u;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;
 
     int32x4x4_t scales32;
 
@@ -7578,7 +7609,7 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
     scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
     int32x4_t sumi = vdupq_n_s32(0);
     for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
-        q8b = vld1q_s8_x4(q8); q8 += 64;
+        q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
         q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
         q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
         q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
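[Annotation, not part of the commit] The int8x16x4_t -> ggml_int8x16x4_t and vld1q_s8_x4 -> ggml_vld1q_s8_x4 renames in these two hunks point at ggml's own compatibility wrappers: some toolchains lack the x4 load intrinsics, so ggml carries a struct plus a helper built from four plain loads. A sketch of that shape, under the assumption that the actual definitions live elsewhere in this file:

#include <arm_neon.h>

typedef struct ggml_int8x16x4_t {
    int8x16_t val[4];
} ggml_int8x16x4_t;

/* Loads 64 contiguous int8 values as four 16-byte vectors,
 * standing in for vld1q_s8_x4 where it is unavailable. */
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
    ggml_int8x16x4_t res;
    res.val[0] = vld1q_s8(ptr +  0);
    res.val[1] = vld1q_s8(ptr + 16);
    res.val[2] = vld1q_s8(ptr + 32);
    res.val[3] = vld1q_s8(ptr + 48);
    return res;
}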