@@ -525,67 +525,47 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
-        const void * b_ptr = vx;
-        const void * a_ptr = vy;
-        float * res_ptr = s;
-
-        __asm__ __volatile__(
-            "movi v31.16b, #0x4\n"
-            "movi v30.16b, #0xf0\n"
-            "add %x[b_ptr], %x[b_ptr], #0x8\n"
-            "1:" // Column loop
-            "add x22, %x[a_ptr], #0x2\n"
-            "movi v29.16b, #0x0\n"
-            "mov x21, %x[nb]\n"
-            "2:" // Block loop
-            "ldr q28, [%x[b_ptr], #0x0]\n"
-            "ldr q27, [x22, #0x0]\n"
-            "movi v26.4s, #0x0\n"
-            "sub x20, x22, #0x2\n"
-            "ldr q25, [x22, #0x10]\n"
-            "ldr q24, [%x[b_ptr], #0x10]\n"
-            "sub x21, x21, #0x1\n"
-            "add x22, x22, #0x22\n"
-            "ldr q23, [%x[b_ptr], #0x20]\n"
-            "ldr q22, [%x[b_ptr], #0x30]\n"
-            "ld1r { v21.8h }, [x20]\n"
-            "ldr q20, [%x[b_ptr], #-0x8]\n"
-            "sshl v16.16b, v28.16b, v31.16b\n"
-            "and v28.16b, v28.16b, v30.16b\n"
-            "sshl v19.16b, v24.16b, v31.16b\n"
-            "and v24.16b, v24.16b, v30.16b\n"
-            "add %x[b_ptr], %x[b_ptr], #0x48\n"
-            "sshl v18.16b, v23.16b, v31.16b\n"
-            "and v23.16b, v23.16b, v30.16b\n"
-            ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
-            "sshl v17.16b, v22.16b, v31.16b\n"
-            "and v22.16b, v22.16b, v30.16b\n"
-            "fcvtl v21.4s, v21.4h\n"
-            "fcvtl v16.4s, v20.4h\n"
-            ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
-            "fmul v16.4s, v16.4s, v21.4s\n"
-            ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
-            ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
-            ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
-            ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
-            ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
-            ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
-            "scvtf v26.4s, v26.4s, #0x4\n"
-            "fmla v29.4s, v26.4s, v16.4s\n"
-            "cbnz x21, 2b\n"
-            "sub %x[nc], %x[nc], #0x4\n"
-            "str q29, [%x[res_ptr], #0x0]\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "cbnz %x[nc], 1b\n"
-            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-            : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
-        );
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
+
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+                int8x16_t a0 = vld1q_s8(a_ptr->qs);
+                int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+                int32x4_t ret = vdupq_n_s32(0);
+
+                ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
+                ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
+                ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
+                ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
+
+                ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
+                ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
+                ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
+                ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                                vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
        return;
    }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)

    float sumf[4];
    int sumi;
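For readers less familiar with the dot-product intrinsics introduced above: `vdotq_laneq_s32(r, a, b, lane)` adds, to each 32-bit lane `i` of `r`, the dot product of the `i`-th group of four signed bytes of `a` with the `lane`-th group of four bytes of `b`. That is why the kernel can feed `b0 << 4` (low nibbles moved into the high half of each byte) and `b0 & 0xf0` (high nibbles kept in place) into the same accumulator: both operands carry a 4-bit quant scaled by 16, and the final `vcvtq_n_f32_s32(ret, 4)` divides the integer sum by 2^4 when converting to float, exactly like the removed assembly's `sshl #4`, `and 0xf0`, and `scvtf ..., #0x4` sequence. The sketch below is a plain-C scalar model of the lane-dot reduction, an illustration only (the helper name `sdot_lane_ref` is made up here), not part of the patch or of the ggml API.

```c
// Scalar model of vdotq_laneq_s32(r, a, b, lane):
// each 32-bit lane i of r accumulates the dot product of bytes
// a[4*i .. 4*i+3] with the four bytes of b selected by `lane`.
#include <stdint.h>
#include <stdio.h>

static void sdot_lane_ref(int32_t r[4], const int8_t a[16], const int8_t b[16], int lane) {
    for (int i = 0; i < 4; i++) {
        int32_t sum = 0;
        for (int j = 0; j < 4; j++) {
            sum += (int32_t) a[4*i + j] * (int32_t) b[4*lane + j];
        }
        r[i] += sum;  // accumulate, as the NEON instruction does
    }
}

int main(void) {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; i++) { a[i] = (int8_t)(i - 8); b[i] = (int8_t) i; }
    int32_t r[4] = {0, 0, 0, 0};
    sdot_lane_ref(r, a, b, 0);  // one of the eight per-block reductions in the kernel
    printf("%d %d %d %d\n", r[0], r[1], r[2], r[3]);
    return 0;
}
```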