@@ -564,21 +564,21 @@ static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
564
564
565
565
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
566
566
if (ggml_cpu_has_neon () && ggml_cpu_has_dotprod ()) {
567
- const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
567
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
568
568
569
569
for (int c = 0 ; c < nc; c += ncols_interleaved) {
570
- const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
570
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
571
571
float32x4_t acc = vdupq_n_f32 (0 );
572
572
for (int b = 0 ; b < nb; b++) {
573
- int8x16_t b0 = vld1q_s8 ((const int8_t *)b_ptr->qs );
574
- int8x16_t b1 = vld1q_s8 ((const int8_t *)b_ptr->qs + 16 );
575
- int8x16_t b2 = vld1q_s8 ((const int8_t *)b_ptr->qs + 32 );
576
- int8x16_t b3 = vld1q_s8 ((const int8_t *)b_ptr->qs + 48 );
577
- float16x4_t bd = vld1_f16 ((const __fp16 *)b_ptr->d );
573
+ int8x16_t b0 = vld1q_s8 ((const int8_t *) b_ptr->qs );
574
+ int8x16_t b1 = vld1q_s8 ((const int8_t *) b_ptr->qs + 16 );
575
+ int8x16_t b2 = vld1q_s8 ((const int8_t *) b_ptr->qs + 32 );
576
+ int8x16_t b3 = vld1q_s8 ((const int8_t *) b_ptr->qs + 48 );
577
+ float16x4_t bd = vld1_f16 ((const __fp16 *) b_ptr->d );
578
578
579
579
int8x16_t a0 = vld1q_s8 (a_ptr->qs );
580
580
int8x16_t a1 = vld1q_s8 (a_ptr->qs + qk/2 );
581
- float16x4_t ad = vld1_dup_f16 ((const __fp16 *)&a_ptr->d );
581
+ float16x4_t ad = vld1_dup_f16 ((const __fp16 *) &a_ptr->d );
582
582
583
583
int32x4_t ret = vdupq_n_s32 (0 );
584
584
@@ -647,72 +647,52 @@ static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
647
647
UNUSED (ncols_interleaved);
648
648
UNUSED (blocklen);
649
649
650
- #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
651
- if (ggml_cpu_has_neon () && ggml_cpu_has_matmul_int8 ()) {
652
- const void * b_ptr = vx;
653
- const void * a_ptr = vy;
654
- float * res_ptr = s;
650
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
651
+ if (ggml_cpu_has_neon () && ggml_cpu_has_dotprod ()) {
652
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
655
653
656
- __asm__ __volatile__ (
657
- " movi v2.16b, #0x4\n "
658
- " movi v1.16b, #0xf0\n "
659
- " add %x[b_ptr], %x[b_ptr], #0x8\n "
660
- " 1:" // Column loop
661
- " add x23, %x[a_ptr], #0x2\n "
662
- " movi v0.16b, #0x0\n "
663
- " mov x22, %x[nb]\n "
664
- " 2:" // Block loop
665
- " ldr q31, [%x[b_ptr], #0x0]\n "
666
- " ldr q30, [%x[b_ptr], #0x10]\n "
667
- " mov x21, x23\n "
668
- " movi v29.4s, #0x0\n "
669
- " ldr q28, [%x[b_ptr], #0x20]\n "
670
- " ldr q27, [%x[b_ptr], #0x30]\n "
671
- " movi v26.4s, #0x0\n "
672
- " sub x20, x23, #0x2\n "
673
- " ld1r { v25.8h }, [x20]\n "
674
- " ldr q24, [%x[b_ptr], #-0x8]\n "
675
- " sub x22, x22, #0x1\n "
676
- " add x23, x23, #0x22\n "
677
- " ld1r { v23.2d }, [x21], #0x8\n "
678
- " sshl v22.16b, v31.16b, v2.16b\n "
679
- " sshl v16.16b, v30.16b, v2.16b\n "
680
- " add %x[b_ptr], %x[b_ptr], #0x48\n "
681
- " ld1r { v21.2d }, [x21], #0x8\n "
682
- " sshl v20.16b, v28.16b, v2.16b\n "
683
- " sshl v19.16b, v27.16b, v2.16b\n "
684
- " ld1r { v18.2d }, [x21], #0x8\n "
685
- " ld1r { v17.2d }, [x21], #0x8\n "
686
- " and v31.16b, v31.16b, v1.16b\n "
687
- " and v30.16b, v30.16b, v1.16b\n "
688
- " .inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n "
689
- " .inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n "
690
- " and v28.16b, v28.16b, v1.16b\n "
691
- " and v27.16b, v27.16b, v1.16b\n "
692
- " fcvtl v25.4s, v25.4h\n "
693
- " fcvtl v16.4s, v24.4h\n "
694
- " .inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n "
695
- " .inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n "
696
- " fmul v16.4s, v16.4s, v25.4s\n "
697
- " .inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n "
698
- " .inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n "
699
- " .inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n "
700
- " .inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n "
701
- " addp v29.4s, v29.4s, v26.4s\n "
702
- " scvtf v29.4s, v29.4s, #0x4\n "
703
- " fmla v0.4s, v29.4s, v16.4s\n "
704
- " cbnz x22, 2b\n "
705
- " sub %x[nc], %x[nc], #0x4\n "
706
- " str q0, [%x[res_ptr], #0x0]\n "
707
- " add %x[res_ptr], %x[res_ptr], #0x10\n "
708
- " cbnz %x[nc], 1b\n "
709
- : [b_ptr] " +&r" (b_ptr), [res_ptr] " +&r" (res_ptr), [nc] " +&r" (nc)
710
- : [a_ptr] " r" (a_ptr), [nb] " r" (nb)
711
- : " memory" , " v0" , " v1" , " v2" , " v16" , " v17" , " v18" , " v19" , " v20" , " v21" , " v22" , " v23" , " v24" , " v25" , " v26" , " v27" , " v28" , " v29" , " v30" , " v31" , " x20" , " x21" , " x22" , " x23"
712
- );
654
+ for (int c = 0 ; c < nc; c += ncols_interleaved) {
655
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
656
+ float32x4_t acc = vdupq_n_f32 (0 );
657
+ for (int b = 0 ; b < nb; b++) {
658
+ int8x16_t b0 = vld1q_s8 ((const int8_t *) b_ptr->qs );
659
+ int8x16_t b1 = vld1q_s8 ((const int8_t *) b_ptr->qs + 16 );
660
+ int8x16_t b2 = vld1q_s8 ((const int8_t *) b_ptr->qs + 32 );
661
+ int8x16_t b3 = vld1q_s8 ((const int8_t *) b_ptr->qs + 48 );
662
+ float16x4_t bd = vld1_f16 ((const __fp16 *) b_ptr->d );
663
+
664
+ int8x16_t a0 = (int8x16_t ) vld1q_dup_s64 ((const int64_t *) a_ptr->qs );
665
+ int8x16_t a1 = (int8x16_t ) vld1q_dup_s64 ((const int64_t *) a_ptr->qs + 1 );
666
+ int8x16_t a2 = (int8x16_t ) vld1q_dup_s64 ((const int64_t *) a_ptr->qs + 2 );
667
+ int8x16_t a3 = (int8x16_t ) vld1q_dup_s64 ((const int64_t *) a_ptr->qs + 3 );
668
+ float16x4_t ad = vld1_dup_f16 ((const __fp16 *) &a_ptr->d );
669
+
670
+ int32x4_t ret0 = vdupq_n_s32 (0 );
671
+ int32x4_t ret1 = vdupq_n_s32 (0 );
672
+
673
+ ret0 = vdotq_s32 (ret0, b0 << 4 , a0);
674
+ ret1 = vdotq_s32 (ret1, b1 << 4 , a0);
675
+ ret0 = vdotq_s32 (ret0, b2 << 4 , a1);
676
+ ret1 = vdotq_s32 (ret1, b3 << 4 , a1);
677
+
678
+ ret0 = vdotq_s32 (ret0, b0 & 0xf0U , a2);
679
+ ret1 = vdotq_s32 (ret1, b1 & 0xf0U , a2);
680
+ ret0 = vdotq_s32 (ret0, b2 & 0xf0U , a3);
681
+ ret1 = vdotq_s32 (ret1, b3 & 0xf0U , a3);
682
+
683
+ int32x4_t ret = vpaddq_s32 (ret0, ret1);
684
+
685
+ acc = vfmaq_f32 (acc, vcvtq_n_f32_s32 (ret, 4 ),
686
+ vmulq_f32 (vcvt_f32_f16 (ad), vcvt_f32_f16 (bd)));
687
+ a_ptr++;
688
+ b_ptr++;
689
+ }
690
+ vst1q_f32 (s, acc);
691
+ s += ncols_interleaved;
692
+ }
713
693
return ;
714
694
}
715
- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8 )
695
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD )
716
696
float sumf[4 ];
717
697
int sumi;
718
698
0 commit comments