Skip to content

Commit cbd8a1a

Browse files
committed
Update quantize_row_q4_0 for Arm NEON
Untested
1 parent 6e64a6a commit cbd8a1a

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

ggml.c

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -604,22 +604,24 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
604604
#elif __ARM_NEON
605605
for (int i = 0; i < nb; i++) {
606606
float32x4_t srcv [8];
607-
float32x4_t asrcv[8];
608-
float32x4_t amaxv[8];
607+
float32x4_t maxv[8];
608+
float32x4_t minv[8];
609609

610610
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
611-
for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
612611

613-
for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
614-
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
615-
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
612+
for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
613+
for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
614+
for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
616615

617-
// absolute max
618-
const float amax = MAX(
619-
MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
620-
MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
616+
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
617+
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
618+
for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
621619

622-
const float d = amax / ((1 << 3) - 1);
620+
const float max = vmaxvq_f32(maxv[0]);
621+
const float min = vminvq_f32(minv[0]);
622+
623+
const float magnitude = max >= fabsf(min) ? max : min;
624+
const float d = magnitude / -8;
623625
const float id = d ? 1.0f/d : 0.0f;
624626

625627
y[i].d = d;
@@ -628,9 +630,10 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
628630
const float32x4_t v = vmulq_n_f32(srcv[l], id);
629631
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
630632
const int32x4_t vi = vcvtq_s32_f32(vf);
633+
const int32x4 vc = vminq_u32(vi, vdupq_n_u32(15));
631634

632-
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
633-
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
635+
y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
636+
y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
634637
}
635638
}
636639
#elif defined(__AVX2__)

0 commit comments

Comments
 (0)