Skip to content

Commit b89a0ba

Browse files
committed
iq1_m: faster ARM_NEON dot product
11.65 t/s -> 14.9 t/s
1 parent eada2a7 commit b89a0ba

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

ggml-quants.c

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9763,12 +9763,16 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
97639763

97649764
#if defined __ARM_NEON
97659765

9766-
const int8x8_t minus1 = vdup_n_s8(-1);
9767-
const int8x8_t plus1 = vdup_n_s8(+1);
97689766
const int32x4_t mask = vdupq_n_s32(0x7);
97699767
const int32x4_t mone = vdupq_n_s32(1);
97709768
const int32x4_t mzero = vdupq_n_s32(0);
97719769

9770+
ggml_int8x16x4_t deltas;
9771+
deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
9772+
deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
9773+
deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
9774+
deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));
9775+
97729776
ggml_int8x16x4_t q1b;
97739777
ggml_int8x16x4_t q8b;
97749778
ggml_int8x16x4_t delta;
@@ -9805,10 +9809,10 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
98059809
const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
98069810
const int32x4_t p12 = vpaddq_s32(p1, p2);
98079811

9808-
delta.val[0] = vcombine_s8(qh[0] & 0x08 ? minus1 : plus1, qh[0] & 0x80 ? minus1 : plus1);
9809-
delta.val[1] = vcombine_s8(qh[1] & 0x08 ? minus1 : plus1, qh[1] & 0x80 ? minus1 : plus1);
9810-
delta.val[2] = vcombine_s8(qh[2] & 0x08 ? minus1 : plus1, qh[2] & 0x80 ? minus1 : plus1);
9811-
delta.val[3] = vcombine_s8(qh[3] & 0x08 ? minus1 : plus1, qh[3] & 0x80 ? minus1 : plus1);
9812+
delta.val[0] = deltas.val[((qh[0] & 0x08) >> 3) | ((qh[0] & 0x80) >> 6)];
9813+
delta.val[1] = deltas.val[((qh[1] & 0x08) >> 3) | ((qh[1] & 0x80) >> 6)];
9814+
delta.val[2] = deltas.val[((qh[2] & 0x08) >> 3) | ((qh[2] & 0x80) >> 6)];
9815+
delta.val[3] = deltas.val[((qh[3] & 0x08) >> 3) | ((qh[3] & 0x80) >> 6)];
98129816

98139817
const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, delta.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, delta.val[1], q8b.val[1]));
98149818
const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, delta.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, delta.val[3], q8b.val[3]));

0 commit comments

Comments (0)