Skip to content

Commit e34c5af

Browse files
authored
ggml-cpu: replace NEON asm with intrinsics in ggml_gemv_q4_0_4x8_q8_0() (#10874)
* ggml-cpu: replace NEON asm with intrinsics in ggml_gemv_q4_0_4x8_q8_0() Signed-off-by: Adrien Gallouët <[email protected]> * ggml-cpu: format code Signed-off-by: Adrien Gallouët <[email protected]> --------- Signed-off-by: Adrien Gallouët <[email protected]>
1 parent eb5c3dc commit e34c5af

File tree

1 file changed

+51
-71
lines changed

1 file changed

+51
-71
lines changed

ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp

Lines changed: 51 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -564,21 +564,21 @@ static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
564564

565565
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
566566
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
567-
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
567+
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
568568

569569
for (int c = 0; c < nc; c += ncols_interleaved) {
570-
const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
570+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
571571
float32x4_t acc = vdupq_n_f32(0);
572572
for (int b = 0; b < nb; b++) {
573-
int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
574-
int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
575-
int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
576-
int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
577-
float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
573+
int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
574+
int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
575+
int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
576+
int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
577+
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
578578

579579
int8x16_t a0 = vld1q_s8(a_ptr->qs);
580580
int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
581-
float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
581+
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
582582

583583
int32x4_t ret = vdupq_n_s32(0);
584584

@@ -647,72 +647,52 @@ static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
647647
UNUSED(ncols_interleaved);
648648
UNUSED(blocklen);
649649

650-
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
651-
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
652-
const void * b_ptr = vx;
653-
const void * a_ptr = vy;
654-
float * res_ptr = s;
650+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
651+
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
652+
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
655653

656-
__asm__ __volatile__(
657-
"movi v2.16b, #0x4\n"
658-
"movi v1.16b, #0xf0\n"
659-
"add %x[b_ptr], %x[b_ptr], #0x8\n"
660-
"1:" // Column loop
661-
"add x23, %x[a_ptr], #0x2\n"
662-
"movi v0.16b, #0x0\n"
663-
"mov x22, %x[nb]\n"
664-
"2:" // Block loop
665-
"ldr q31, [%x[b_ptr], #0x0]\n"
666-
"ldr q30, [%x[b_ptr], #0x10]\n"
667-
"mov x21, x23\n"
668-
"movi v29.4s, #0x0\n"
669-
"ldr q28, [%x[b_ptr], #0x20]\n"
670-
"ldr q27, [%x[b_ptr], #0x30]\n"
671-
"movi v26.4s, #0x0\n"
672-
"sub x20, x23, #0x2\n"
673-
"ld1r { v25.8h }, [x20]\n"
674-
"ldr q24, [%x[b_ptr], #-0x8]\n"
675-
"sub x22, x22, #0x1\n"
676-
"add x23, x23, #0x22\n"
677-
"ld1r { v23.2d }, [x21], #0x8\n"
678-
"sshl v22.16b, v31.16b, v2.16b\n"
679-
"sshl v16.16b, v30.16b, v2.16b\n"
680-
"add %x[b_ptr], %x[b_ptr], #0x48\n"
681-
"ld1r { v21.2d }, [x21], #0x8\n"
682-
"sshl v20.16b, v28.16b, v2.16b\n"
683-
"sshl v19.16b, v27.16b, v2.16b\n"
684-
"ld1r { v18.2d }, [x21], #0x8\n"
685-
"ld1r { v17.2d }, [x21], #0x8\n"
686-
"and v31.16b, v31.16b, v1.16b\n"
687-
"and v30.16b, v30.16b, v1.16b\n"
688-
".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n"
689-
".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n"
690-
"and v28.16b, v28.16b, v1.16b\n"
691-
"and v27.16b, v27.16b, v1.16b\n"
692-
"fcvtl v25.4s, v25.4h\n"
693-
"fcvtl v16.4s, v24.4h\n"
694-
".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
695-
".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
696-
"fmul v16.4s, v16.4s, v25.4s\n"
697-
".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
698-
".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
699-
".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
700-
".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
701-
"addp v29.4s, v29.4s, v26.4s\n"
702-
"scvtf v29.4s, v29.4s, #0x4\n"
703-
"fmla v0.4s, v29.4s, v16.4s\n"
704-
"cbnz x22, 2b\n"
705-
"sub %x[nc], %x[nc], #0x4\n"
706-
"str q0, [%x[res_ptr], #0x0]\n"
707-
"add %x[res_ptr], %x[res_ptr], #0x10\n"
708-
"cbnz %x[nc], 1b\n"
709-
: [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
710-
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
711-
: "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
712-
);
654+
for (int c = 0; c < nc; c += ncols_interleaved) {
655+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
656+
float32x4_t acc = vdupq_n_f32(0);
657+
for (int b = 0; b < nb; b++) {
658+
int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
659+
int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
660+
int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
661+
int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
662+
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
663+
664+
int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
665+
int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
666+
int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
667+
int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
668+
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
669+
670+
int32x4_t ret0 = vdupq_n_s32(0);
671+
int32x4_t ret1 = vdupq_n_s32(0);
672+
673+
ret0 = vdotq_s32(ret0, b0 << 4, a0);
674+
ret1 = vdotq_s32(ret1, b1 << 4, a0);
675+
ret0 = vdotq_s32(ret0, b2 << 4, a1);
676+
ret1 = vdotq_s32(ret1, b3 << 4, a1);
677+
678+
ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
679+
ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
680+
ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
681+
ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
682+
683+
int32x4_t ret = vpaddq_s32(ret0, ret1);
684+
685+
acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
686+
vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
687+
a_ptr++;
688+
b_ptr++;
689+
}
690+
vst1q_f32(s, acc);
691+
s += ncols_interleaved;
692+
}
713693
return;
714694
}
715-
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
695+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
716696
float sumf[4];
717697
int sumi;
718698

0 commit comments

Comments
 (0)