@@ -462,6 +462,30 @@ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
462
462
return res ;
463
463
}
464
464
465
+ // NOTE: not tested
466
+ inline static int8x16_t ggml_vqtbl1q_u8 (uint8x16_t a , uint8x16_t b ) {
467
+ int8x16_t res ;
468
+
469
+ res [ 0 ] = a [b [ 0 ]];
470
+ res [ 1 ] = a [b [ 1 ]];
471
+ res [ 2 ] = a [b [ 2 ]];
472
+ res [ 3 ] = a [b [ 3 ]];
473
+ res [ 4 ] = a [b [ 4 ]];
474
+ res [ 5 ] = a [b [ 5 ]];
475
+ res [ 6 ] = a [b [ 6 ]];
476
+ res [ 7 ] = a [b [ 7 ]];
477
+ res [ 8 ] = a [b [ 8 ]];
478
+ res [ 9 ] = a [b [ 9 ]];
479
+ res [10 ] = a [b [10 ]];
480
+ res [11 ] = a [b [11 ]];
481
+ res [12 ] = a [b [12 ]];
482
+ res [13 ] = a [b [13 ]];
483
+ res [14 ] = a [b [14 ]];
484
+ res [15 ] = a [b [15 ]];
485
+
486
+ return res ;
487
+ }
488
+
465
489
#else
466
490
467
491
#define ggml_int16x8x2_t int16x8x2_t
@@ -476,6 +500,7 @@ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
476
500
#define ggml_vld1q_s8_x2 vld1q_s8_x2
477
501
#define ggml_vld1q_s8_x4 vld1q_s8_x4
478
502
#define ggml_vqtbl1q_s8 vqtbl1q_s8
503
+ #define ggml_vqtbl1q_u8 vqtbl1q_u8
479
504
480
505
#endif
481
506
@@ -9488,17 +9513,17 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
9488
9513
qs += 16 ;
9489
9514
9490
9515
vs .val [0 ] = vreinterpretq_u8_u32 (vdupq_n_u32 (signs [0 ] | (signs [1 ] << 16 )));
9491
- vs .val [1 ] = vandq_u8 (vqtbl1q_u8 (vs .val [0 ], mask1 .val [1 ]), mask2 );
9492
- vs .val [0 ] = vandq_u8 (vqtbl1q_u8 (vs .val [0 ], mask1 .val [0 ]), mask2 );
9516
+ vs .val [1 ] = vandq_u8 (ggml_vqtbl1q_u8 (vs .val [0 ], mask1 .val [1 ]), mask2 );
9517
+ vs .val [0 ] = vandq_u8 (ggml_vqtbl1q_u8 (vs .val [0 ], mask1 .val [0 ]), mask2 );
9493
9518
vs .val [0 ] = vceqq_u8 (vs .val [0 ], mask2 );
9494
9519
vs .val [1 ] = vceqq_u8 (vs .val [1 ], mask2 );
9495
9520
9496
9521
q3s .val [0 ] = vsubq_s8 (vreinterpretq_s8_u8 (veorq_u8 (vs .val [0 ], vreinterpretq_u8_u32 (aux32x4_0 ))), vreinterpretq_s8_u8 (vs .val [0 ]));
9497
9522
q3s .val [1 ] = vsubq_s8 (vreinterpretq_s8_u8 (veorq_u8 (vs .val [1 ], vreinterpretq_u8_u32 (aux32x4_1 ))), vreinterpretq_s8_u8 (vs .val [1 ]));
9498
9523
9499
9524
vs .val [0 ] = vreinterpretq_u8_u32 (vdupq_n_u32 (signs [2 ] | (signs [3 ] << 16 )));
9500
- vs .val [1 ] = vandq_u8 (vqtbl1q_u8 (vs .val [0 ], mask1 .val [1 ]), mask2 );
9501
- vs .val [0 ] = vandq_u8 (vqtbl1q_u8 (vs .val [0 ], mask1 .val [0 ]), mask2 );
9525
+ vs .val [1 ] = vandq_u8 (ggml_vqtbl1q_u8 (vs .val [0 ], mask1 .val [1 ]), mask2 );
9526
+ vs .val [0 ] = vandq_u8 (ggml_vqtbl1q_u8 (vs .val [0 ], mask1 .val [0 ]), mask2 );
9502
9527
vs .val [0 ] = vceqq_u8 (vs .val [0 ], mask2 );
9503
9528
vs .val [1 ] = vceqq_u8 (vs .val [1 ], mask2 );
9504
9529
0 commit comments