Skip to content

Commit 34cdd78

Browse files
committed
ggml : nibbles_from_floats() + bytes_from_nibbles() (ARM NEON)
1 parent c6a3266 commit 34cdd78

File tree

1 file changed

+60
-71
lines changed

1 file changed

+60
-71
lines changed

ggml.c

Lines changed: 60 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,50 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
608608

609609
#if __ARM_NEON
610610

611+
static inline const uint8_t * bytes_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
612+
memcpy(qd, qs, qk/2);
613+
614+
for (int l = 0; l < qk/16; ++l) {
615+
qd[l + qk/16] = (qd[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
616+
qd[l + 0 ] = (qd[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
617+
}
618+
619+
return (const uint8_t *) qd;
620+
}
621+
622+
// pack first half of weights into low nibbles and second half into high nibbles
623+
// use one scaling factor
624+
static inline void nibbles_from_floats_64_0(const int qk, const float * x, float id, uint8_t * qs, uint64_t * qd) {
625+
for (int l = 0; l < qk/2; ++l) {
626+
const float v0 = x[0 + l]*id;
627+
const float v1 = x[qk/2 + l]*id;
628+
629+
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
630+
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
631+
632+
qd[l/8] |= vi0 << (8*(l & 7));
633+
qd[l/8] |= vi1 << (8*(l & 7) + 4);
634+
}
635+
636+
memcpy(qs, qd, qk/2);
637+
}
638+
639+
// use offset and scaling factor
640+
static inline void nibbles_from_floats_64_1(const int qk, const float * x, float id, float min, uint8_t * qs, uint64_t * qd) {
641+
for (int l = 0; l < qk/2; ++l) {
642+
const float v0 = (x[0 + l] - min)*id;
643+
const float v1 = (x[qk/2 + l] - min)*id;
644+
645+
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
646+
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
647+
648+
qd[l/8] |= vi0 << (8*(l & 7));
649+
qd[l/8] |= vi1 << (8*(l & 7) + 4);
650+
}
651+
652+
memcpy(qs, qd, qk/2);
653+
}
654+
611655
#if !defined(__aarch64__)
612656

613657
inline static uint16_t vaddvq_u8(uint8x16_t v) {
@@ -856,19 +900,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
856900

857901
uint64_t qs[QK4_0 / 16] = {0};
858902

859-
// pack first half of weights into low nibbles and second half into high nibbles
860-
for (int l = 0; l < qk/2; ++l) {
861-
const float v0 = x[i*qk + 0 + l]*id;
862-
const float v1 = x[i*qk + qk/2 + l]*id;
863-
864-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
865-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
866-
867-
qs[l/8] |= vi0 << (8*(l & 7));
868-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
869-
}
870-
871-
memcpy(y[i].qs, qs, sizeof(qs));
903+
nibbles_from_floats_64_0(qk, x + i*qk, id, y[i].qs, qs);
872904
}
873905
}
874906

@@ -903,19 +935,7 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
903935

904936
uint64_t qs[QK4_1 / 16] = {0};
905937

906-
// pack first half of weights into low nibbles and second half into high nibbles
907-
for (int l = 0; l < qk/2; ++l) {
908-
const float v0 = (x[i*qk + 0 + l] - min)*id;
909-
const float v1 = (x[i*qk + qk/2 + l] - min)*id;
910-
911-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
912-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
913-
914-
qs[l/8] |= vi0 << (8*(l & 7));
915-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
916-
}
917-
918-
memcpy(y[i].qs, qs, sizeof(qs));
938+
nibbles_from_floats_64_1(qk, x + i*qk, id, min, y[i].qs, qs);
919939
}
920940
}
921941

@@ -1308,20 +1328,12 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
13081328

13091329
const int nb = k / qk;
13101330

1331+
uint64_t qs[QK4_0 / 8];
1332+
13111333
for (int i = 0; i < nb; i++) {
13121334
const float d = x[i].d;
13131335

1314-
// unpack nibbles into bytes
1315-
uint64_t qs[QK4_0 / 8] = {0};
1316-
1317-
memcpy(qs + 0, x[i].qs, sizeof(x[i].qs));
1318-
1319-
for (int l = 0; l < qk / 16; ++l) {
1320-
qs[l + qk/16] = (qs[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
1321-
qs[l + 0 ] = (qs[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
1322-
}
1323-
1324-
const uint8_t * restrict qsp = (const uint8_t * restrict) qs;
1336+
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
13251337

13261338
for (int l = 0; l < qk; ++l) {
13271339
y[i*qk + l] = (qsp[l] - 8)*d;
@@ -1337,21 +1349,13 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
13371349

13381350
const int nb = k / qk;
13391351

1352+
uint64_t qs[QK4_0 / 8];
1353+
13401354
for (int i = 0; i < nb; i++) {
13411355
const float d = x[i].d;
13421356
const float m = x[i].m;
13431357

1344-
// unpack nibbles into bytes
1345-
uint64_t qs[QK4_0 / 8] = {0};
1346-
1347-
memcpy(qs + 0, x[i].qs, sizeof(x[i].qs));
1348-
1349-
for (int l = 0; l < qk / 16; ++l) {
1350-
qs[l + qk/16] = (qs[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
1351-
qs[l + 0 ] = (qs[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
1352-
}
1353-
1354-
const uint8_t * restrict qsp = (const uint8_t * restrict) qs;
1358+
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
13551359

13561360
for (int l = 0; l < qk; ++l) {
13571361
y[i*qk + l] = qsp[l]*d + m;
@@ -2283,19 +2287,12 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
22832287
// scalar
22842288
float sumf = 0.0;
22852289

2290+
uint64_t qs[QK8_0 / 8];
2291+
22862292
for (int i = 0; i < nb; i++) {
22872293
// unpack nibbles into bytes
2288-
uint64_t qs[QK8_0 / 8] = {0};
2289-
2290-
memcpy(qs + 0, x[i].qs, sizeof(x[i].qs));
2291-
2292-
for (int l = 0; l < qk / 16; ++l) {
2293-
qs[l + qk/16] = (qs[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
2294-
qs[l + 0 ] = (qs[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
2295-
}
2296-
2297-
const uint8_t * restrict px = (const uint8_t * restrict) qs;
2298-
const int8_t * restrict py = y[i].qs;
2294+
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2295+
const int8_t * py = y[i].qs;
22992296

23002297
int sumi = 0;
23012298

@@ -2415,19 +2412,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
24152412
// scalar
24162413
float sumf = 0.0;
24172414

2418-
for (int i = 0; i < nb; i++) {
2419-
// unpack nibbles into bytes
2420-
uint64_t qs[QK8_1 / 8] = {0};
2421-
2422-
memcpy(qs + 0, x[i].qs, sizeof(x[i].qs));
2415+
uint64_t qs[QK8_1 / 8];
24232416

2424-
for (int l = 0; l < qk / 16; ++l) {
2425-
qs[l + qk/16] = (qs[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
2426-
qs[l + 0 ] = (qs[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
2427-
}
2428-
2429-
const uint8_t * restrict px = (const uint8_t * restrict) qs;
2430-
const int8_t * restrict py = y[i].qs;
2417+
for (int i = 0; i < nb; i++) {
2418+
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2419+
const int8_t * py = y[i].qs;
24312420

24322421
int sumi = 0;
24332422

0 commit comments

Comments
 (0)