Skip to content

Commit 0c39f44

Browse files
authored
ggml-cpu: replace AArch64 NEON assembly with intrinsics in ggml_gemv_q4_0_4x4_q8_0() (#10567)
Signed-off-by: Adrien Gallouët <[email protected]>
1 parent 3e0ba0e commit 0c39f44

File tree

1 file changed

+38
-58
lines changed

1 file changed

+38
-58
lines changed

ggml/src/ggml-cpu/ggml-cpu-aarch64.c

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -525,67 +525,47 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
525525
UNUSED(ncols_interleaved);
526526
UNUSED(blocklen);
527527

528-
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
528+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
529529
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
530-
const void * b_ptr = vx;
531-
const void * a_ptr = vy;
532-
float * res_ptr = s;
533-
534-
__asm__ __volatile__(
535-
"movi v31.16b, #0x4\n"
536-
"movi v30.16b, #0xf0\n"
537-
"add %x[b_ptr], %x[b_ptr], #0x8\n"
538-
"1:" // Column loop
539-
"add x22, %x[a_ptr], #0x2\n"
540-
"movi v29.16b, #0x0\n"
541-
"mov x21, %x[nb]\n"
542-
"2:" // Block loop
543-
"ldr q28, [%x[b_ptr], #0x0]\n"
544-
"ldr q27, [x22, #0x0]\n"
545-
"movi v26.4s, #0x0\n"
546-
"sub x20, x22, #0x2\n"
547-
"ldr q25, [x22, #0x10]\n"
548-
"ldr q24, [%x[b_ptr], #0x10]\n"
549-
"sub x21, x21, #0x1\n"
550-
"add x22, x22, #0x22\n"
551-
"ldr q23, [%x[b_ptr], #0x20]\n"
552-
"ldr q22, [%x[b_ptr], #0x30]\n"
553-
"ld1r { v21.8h }, [x20]\n"
554-
"ldr q20, [%x[b_ptr], #-0x8]\n"
555-
"sshl v16.16b, v28.16b, v31.16b\n"
556-
"and v28.16b, v28.16b, v30.16b\n"
557-
"sshl v19.16b, v24.16b, v31.16b\n"
558-
"and v24.16b, v24.16b, v30.16b\n"
559-
"add %x[b_ptr], %x[b_ptr], #0x48\n"
560-
"sshl v18.16b, v23.16b, v31.16b\n"
561-
"and v23.16b, v23.16b, v30.16b\n"
562-
".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n"
563-
"sshl v17.16b, v22.16b, v31.16b\n"
564-
"and v22.16b, v22.16b, v30.16b\n"
565-
"fcvtl v21.4s, v21.4h\n"
566-
"fcvtl v16.4s, v20.4h\n"
567-
".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
568-
"fmul v16.4s, v16.4s, v21.4s\n"
569-
".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
570-
".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
571-
".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
572-
".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
573-
".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
574-
".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
575-
"scvtf v26.4s, v26.4s, #0x4\n"
576-
"fmla v29.4s, v26.4s, v16.4s\n"
577-
"cbnz x21, 2b\n"
578-
"sub %x[nc], %x[nc], #0x4\n"
579-
"str q29, [%x[res_ptr], #0x0]\n"
580-
"add %x[res_ptr], %x[res_ptr], #0x10\n"
581-
"cbnz %x[nc], 1b\n"
582-
: [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
583-
: [a_ptr] "r" (a_ptr), [nb] "r" (nb)
584-
: "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
585-
);
530+
const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
531+
532+
for (int c = 0; c < nc; c += ncols_interleaved) {
533+
const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
534+
float32x4_t acc = vdupq_n_f32(0);
535+
for (int b = 0; b < nb; b++) {
536+
int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
537+
int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
538+
int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
539+
int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
540+
float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
541+
542+
int8x16_t a0 = vld1q_s8(a_ptr->qs);
543+
int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
544+
float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
545+
546+
int32x4_t ret = vdupq_n_s32(0);
547+
548+
ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
549+
ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
550+
ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
551+
ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
552+
553+
ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
554+
ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
555+
ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
556+
ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
557+
558+
acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
559+
vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
560+
a_ptr++;
561+
b_ptr++;
562+
}
563+
vst1q_f32(s, acc);
564+
s += ncols_interleaved;
565+
}
586566
return;
587567
}
588-
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
568+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
589569
float sumf[4];
590570
int sumi;
591571

0 commit comments

Comments
 (0)