Skip to content

Commit 8dbd7e7

Browse files
committed
ggml : remove Q5_0 bit shuffling (ARM NEON)
1 parent e5826a4 commit 8dbd7e7

File tree

1 file changed

+91
-107
lines changed

1 file changed

+91
-107
lines changed

ggml.c

Lines changed: 91 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -619,39 +619,6 @@ static inline const uint8_t * bytes_from_nibbles_64(const int qk, const uint8_t
619619
return (const uint8_t *) qd;
620620
}
621621

622-
// pack first half of weights into low nibbles and second half into high nibbles
623-
// use one scaling factor
624-
static inline void nibbles_from_floats_64_0(const int qk, const float * x, float id, uint8_t * qs, uint64_t * qd) {
625-
for (int l = 0; l < qk/2; ++l) {
626-
const float v0 = x[0 + l]*id;
627-
const float v1 = x[qk/2 + l]*id;
628-
629-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
630-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
631-
632-
qd[l/8] |= vi0 << (8*(l & 7));
633-
qd[l/8] |= vi1 << (8*(l & 7) + 4);
634-
}
635-
636-
memcpy(qs, qd, qk/2);
637-
}
638-
639-
// use offset and scaling factor
640-
static inline void nibbles_from_floats_64_1(const int qk, const float * x, float id, float min, uint8_t * qs, uint64_t * qd) {
641-
for (int l = 0; l < qk/2; ++l) {
642-
const float v0 = (x[0 + l] - min)*id;
643-
const float v1 = (x[qk/2 + l] - min)*id;
644-
645-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
646-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
647-
648-
qd[l/8] |= vi0 << (8*(l & 7));
649-
qd[l/8] |= vi1 << (8*(l & 7) + 4);
650-
}
651-
652-
memcpy(qs, qd, qk/2);
653-
}
654-
655622
#if !defined(__aarch64__)
656623

657624
inline static uint16_t vaddvq_u8(uint8x16_t v) {
@@ -900,7 +867,18 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
900867

901868
uint64_t qs[QK4_0 / 16] = {0};
902869

903-
nibbles_from_floats_64_0(qk, x + i*qk, id, y[i].qs, qs);
870+
for (int l = 0; l < qk/2; ++l) {
871+
const float v0 = x[i*qk + 0 + l]*id;
872+
const float v1 = x[i*qk + qk/2 + l]*id;
873+
874+
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
875+
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
876+
877+
qs[l/8] |= vi0 << (8*(l & 7));
878+
qs[l/8] |= vi1 << (8*(l & 7) + 4);
879+
}
880+
881+
memcpy(y[i].qs, qs, qk/2);
904882
}
905883
}
906884

@@ -935,7 +913,18 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
935913

936914
uint64_t qs[QK4_1 / 16] = {0};
937915

938-
nibbles_from_floats_64_1(qk, x + i*qk, id, min, y[i].qs, qs);
916+
for (int l = 0; l < qk/2; ++l) {
917+
const float v0 = (x[0 + l] - min)*id;
918+
const float v1 = (x[qk/2 + l] - min)*id;
919+
920+
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
921+
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
922+
923+
qs[l/8] |= vi0 << (8*(l & 7));
924+
qs[l/8] |= vi1 << (8*(l & 7) + 4);
925+
}
926+
927+
memcpy(y[i].qs, qs, qk/2);
939928
}
940929
}
941930

@@ -971,7 +960,18 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
971960

972961
uint64_t qs[QK4_2 / 16] = {0};
973962

974-
nibbles_from_floats_64_0(qk, x + i*qk, id, y[i].qs, qs);
963+
for (int l = 0; l < qk/2; ++l) {
964+
const float v0 = x[i*qk + 0 + l]*id;
965+
const float v1 = x[i*qk + qk/2 + l]*id;
966+
967+
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
968+
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
969+
970+
qs[l/8] |= vi0 << (8*(l & 7));
971+
qs[l/8] |= vi1 << (8*(l & 7) + 4);
972+
}
973+
974+
memcpy(y[i].qs, qs, qk/2);
975975
}
976976
}
977977

@@ -980,51 +980,54 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict y, int k
980980
}
981981

982982
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
983-
assert(k % QK5_0 == 0);
984-
const int nb = k / QK5_0;
983+
static const int qk = QK5_0;
984+
985+
assert(qk / 16 == 0);
986+
assert( k % qk == 0);
987+
988+
const int nb = k / qk;
985989

986990
for (int i = 0; i < nb; i++) {
987991
float amax = 0.0f; // absolute max
988-
float max = 0.0f;
992+
float max = 0.0f;
989993

990-
for (int l = 0; l < QK5_0; l++) {
991-
const float v = x[i*QK5_0 + l];
994+
for (int l = 0; l < qk; l++) {
995+
const float v = x[i*qk + l];
992996
if (amax < fabsf(v)) {
993997
amax = fabsf(v);
994-
max = v;
998+
max = v;
995999
}
9961000
}
9971001

998-
const float d = max / -16;
1002+
const float d = max / -16;
9991003
const float id = d ? 1.0f/d : 0.0f;
10001004

1001-
y[i].d = GGML_FP32_TO_FP16(d);
1005+
y[i].d = d;
10021006

10031007
uint32_t qh = 0;
1008+
uint64_t qs[QK5_0 / 16] = {0};
10041009

1005-
for (int l = 0; l < QK5_0; l += 2) {
1006-
const float v0 = x[i*QK5_0 + l + 0]*id;
1007-
const float v1 = x[i*QK5_0 + l + 1]*id;
1010+
for (int l = 0; l < qk/2; ++l) {
1011+
const float v0 = x[i*qk + 0 + l]*id;
1012+
const float v1 = x[i*qk + qk/2 + l]*id;
10081013

1009-
const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
1010-
const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
1014+
const uint64_t vi0 = MIN(31, (int8_t)(v0 + 16.5f));
1015+
const uint64_t vi1 = MIN(31, (int8_t)(v1 + 16.5f));
10111016

1012-
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1017+
qs[l/8] |= vi0 << (8*(l & 7));
1018+
qs[l/8] |= vi1 << (8*(l & 7) + 4);
10131019

10141020
// get the 5-th bit and store it in qh at the right position
10151021
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1016-
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1022+
qh |= ((vi1 & 0x10) >> 4) << (l + qk/2);
10171023
}
10181024

1019-
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
1025+
memcpy( y[i].qs, qs, qk/2);
1026+
memcpy(&y[i].qh, &qh, sizeof(qh));
10201027
}
10211028
}
10221029

1023-
static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
1024-
assert(k % QK5_0 == 0);
1025-
1026-
block_q5_0 * restrict y = vy;
1027-
1030+
static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) {
10281031
quantize_row_q5_0_reference(x, y, k);
10291032
}
10301033

@@ -1373,38 +1376,28 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
13731376
}
13741377
}
13751378

1376-
static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
1377-
assert(k % QK5_0 == 0);
1378-
const int nb = k / QK5_0;
1379+
static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
1380+
static const int qk = QK4_0;
13791381

1380-
const block_q5_0 * restrict x = vx;
1382+
assert(qk / 16 == 0);
1383+
assert( k % qk == 0);
1384+
1385+
const int nb = k / qk;
1386+
1387+
uint64_t qs[QK5_0 / 8];
13811388

13821389
for (int i = 0; i < nb; i++) {
13831390
const float d = GGML_FP16_TO_FP32(x[i].d);
13841391

1385-
const uint8_t * restrict pp = x[i].qs;
1386-
13871392
uint32_t qh;
13881393
memcpy(&qh, x[i].qh, sizeof(qh));
13891394

1390-
for (int l = 0; l < QK5_0; l += 2) {
1391-
const uint8_t vi = pp[l/2];
1392-
1393-
// extract the 5-th bit from qh
1394-
const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
1395-
const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
1396-
1397-
const int8_t vi0 = (vi & 0x0F) | vh0;
1398-
const int8_t vi1 = (vi >> 4) | vh1;
1399-
1400-
const float v0 = (vi0 - 16)*d;
1401-
const float v1 = (vi1 - 16)*d;
1395+
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
14021396

1403-
y[i*QK5_0 + l + 0] = v0;
1404-
y[i*QK5_0 + l + 1] = v1;
1397+
for (int l = 0; l < qk; ++l) {
1398+
const uint8_t vh = ((qh & (1u << l)) >> l) << 4;
14051399

1406-
assert(!isnan(y[i*QK5_0 + l + 0]));
1407-
assert(!isnan(y[i*QK5_0 + l + 1]));
1400+
y[i*qk + l] = ((qsp[l] | vh) - 16)*d;
14081401
}
14091402
}
14101403
}
@@ -1496,7 +1489,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
14961489
.vec_dot_type = GGML_TYPE_Q8_0,
14971490
},
14981491
[GGML_TYPE_Q5_0] = {
1499-
.dequantize_row_q = dequantize_row_q5_0,
1492+
.dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_0,
15001493
.quantize_row_q = quantize_row_q5_0,
15011494
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
15021495
.quantize_row_q_dot = quantize_row_q8_0,
@@ -2566,11 +2559,12 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
25662559
}
25672560

25682561
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2569-
const int nb = n / QK8_0;
2562+
const int qk = QK8_0;
2563+
const int nb = n / qk;
25702564

2571-
assert(n % QK8_0 == 0);
2565+
assert(n % qk == 0);
25722566
assert(nb % 2 == 0);
2573-
assert(QK8_0 == QK5_0);
2567+
assert(qk == QK5_0);
25742568

25752569
const block_q5_0 * restrict x = vx;
25762570
const block_q8_0 * restrict y = vy;
@@ -2605,13 +2599,9 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
26052599
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
26062600
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
26072601

2608-
// interleave
2609-
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
2610-
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
2611-
26122602
// add high bit and sub 16
2613-
const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
2614-
const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
2603+
const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0l, qhl), s16b);
2604+
const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0h, qhh), s16b);
26152605

26162606
// load y
26172607
const int8x16_t v1l = vld1q_s8(y0->qs);
@@ -2729,34 +2719,28 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
27292719
#else
27302720
// scalar
27312721
float sumf = 0.0;
2722+
2723+
uint64_t qs[QK8_0 / 8];
2724+
27322725
for (int i = 0; i < nb; i++) {
2733-
const uint8_t * restrict x0 = x[i].qs;
2734-
const int8_t * restrict y0 = y[i].qs;
2726+
// unpack nibbles into bytes
2727+
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2728+
const int8_t * py = y[i].qs;
27352729

27362730
uint32_t qh;
27372731
memcpy(&qh, x[i].qh, sizeof(qh));
27382732

2739-
const float d = GGML_FP16_TO_FP32(x[i].d);
2740-
2741-
int sxy = 0;
2742-
2743-
for (int j = 0; j < QK8_0/2; j++) {
2744-
const uint8_t v0 = x0[j];
2745-
2746-
const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
2747-
const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
2748-
2749-
const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
2750-
const int x1_0 = ((v0 >> 4) | x1_0h) - 16;
2733+
int sumi = 0;
27512734

2752-
const int y0_0 = y0[2*j + 0];
2753-
const int y1_0 = y0[2*j + 1];
2735+
for (int j = 0; j < qk; ++j) {
2736+
const int xh = ((qh & (1u << j)) >> j) << 4;
27542737

2755-
sxy += x0_0*y0_0 + x1_0*y1_0;
2738+
sumi += ((px[j] | xh) - 16)*py[j];
27562739
}
27572740

2758-
sumf += (d*sxy)*y[i].d;
2741+
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;
27592742
}
2743+
27602744
*s = sumf;
27612745
#endif
27622746
}

0 commit comments

Comments
 (0)