Skip to content

Commit 39bb8e7

Browse files
committed
ggml : 2x faster scalar implementations
1 parent 796f8ae commit 39bb8e7

File tree

1 file changed

+67
-65
lines changed

1 file changed

+67
-65
lines changed

ggml.c

Lines changed: 67 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,8 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
615615

616616
#if __ARM_NEON
617617

618-
static inline const uint8_t * bytes_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
618+
// TODO: obosolete - will be removed
619+
static inline const uint8_t * b4_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
619620
memcpy(qd, qs, qk/2);
620621

621622
for (int l = 0; l < qk/16; ++l) {
@@ -875,14 +876,14 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
875876
uint64_t qs[QK4_0 / 16] = {0};
876877

877878
for (int l = 0; l < qk/2; ++l) {
878-
const float v0 = x[i*qk + 0 + l]*id;
879-
const float v1 = x[i*qk + qk/2 + l]*id;
879+
const float x0 = x[i*qk + 0 + l]*id;
880+
const float x1 = x[i*qk + qk/2 + l]*id;
880881

881-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
882-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
882+
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
883+
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
883884

884-
qs[l/8] |= vi0 << (8*(l & 7));
885-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
885+
qs[l/8] |= xi0 << (8*(l & 7));
886+
qs[l/8] |= xi1 << (8*(l & 7) + 4);
886887
}
887888

888889
memcpy(y[i].qs, qs, qk/2);
@@ -921,14 +922,14 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
921922
uint64_t qs[QK4_1 / 16] = {0};
922923

923924
for (int l = 0; l < qk/2; ++l) {
924-
const float v0 = (x[0 + l] - min)*id;
925-
const float v1 = (x[qk/2 + l] - min)*id;
925+
const float x0 = (x[0 + l] - min)*id;
926+
const float x1 = (x[qk/2 + l] - min)*id;
926927

927-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
928-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
928+
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
929+
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
929930

930-
qs[l/8] |= vi0 << (8*(l & 7));
931-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
931+
qs[l/8] |= xi0 << (8*(l & 7));
932+
qs[l/8] |= xi1 << (8*(l & 7) + 4);
932933
}
933934

934935
memcpy(y[i].qs, qs, qk/2);
@@ -968,14 +969,14 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
968969
uint64_t qs[QK4_2 / 16] = {0};
969970

970971
for (int l = 0; l < qk/2; ++l) {
971-
const float v0 = x[i*qk + 0 + l]*id;
972-
const float v1 = x[i*qk + qk/2 + l]*id;
972+
const float x0 = x[i*qk + 0 + l]*id;
973+
const float x1 = x[i*qk + qk/2 + l]*id;
973974

974-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
975-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
975+
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
976+
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
976977

977-
qs[l/8] |= vi0 << (8*(l & 7));
978-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
978+
qs[l/8] |= xi0 << (8*(l & 7));
979+
qs[l/8] |= xi1 << (8*(l & 7) + 4);
979980
}
980981

981982
memcpy(y[i].qs, qs, qk/2);
@@ -1015,18 +1016,18 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
10151016
uint64_t qs[QK5_0 / 16] = {0};
10161017

10171018
for (int l = 0; l < qk/2; ++l) {
1018-
const float v0 = x[i*qk + 0 + l]*id;
1019-
const float v1 = x[i*qk + qk/2 + l]*id;
1019+
const float x0 = x[i*qk + 0 + l]*id;
1020+
const float x1 = x[i*qk + qk/2 + l]*id;
10201021

1021-
const uint64_t vi0 = MIN(31, (int8_t)(v0 + 16.5f));
1022-
const uint64_t vi1 = MIN(31, (int8_t)(v1 + 16.5f));
1022+
const uint64_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
1023+
const uint64_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
10231024

1024-
qs[l/8] |= vi0 << (8*(l & 7));
1025-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
1025+
qs[l/8] |= xi0 << (8*(l & 7));
1026+
qs[l/8] |= xi1 << (8*(l & 7) + 4);
10261027

10271028
// get the 5-th bit and store it in qh at the right position
1028-
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1029-
qh |= ((vi1 & 0x10) >> 4) << (l + qk/2);
1029+
qh |= ((xi0 & 0x10) >> 4) << (l + 0);
1030+
qh |= ((xi1 & 0x10) >> 4) << (l + qk/2);
10301031
}
10311032

10321033
memcpy( y[i].qs, qs, qk/2);
@@ -1447,15 +1448,15 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
14471448

14481449
const int nb = k / qk;
14491450

1450-
uint64_t qs[QK4_0 / 8];
1451-
14521451
for (int i = 0; i < nb; i++) {
14531452
const float d = x[i].d;
14541453

1455-
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
1454+
for (int j = 0; j < qk/2; ++j) {
1455+
const int x0 = (x[i].qs[j] & 0xf) - 8;
1456+
const int x1 = (x[i].qs[j] >> 4) - 8;
14561457

1457-
for (int l = 0; l < qk; ++l) {
1458-
y[i*qk + l] = (qsp[l] - 8)*d;
1458+
y[i*qk + j + 0 ] = x0*d;
1459+
y[i*qk + j + qk/2] = x1*d;
14591460
}
14601461
}
14611462
}
@@ -1468,21 +1469,22 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
14681469

14691470
const int nb = k / qk;
14701471

1471-
uint64_t qs[QK4_0 / 8];
1472-
14731472
for (int i = 0; i < nb; i++) {
14741473
const float d = x[i].d;
14751474
const float m = x[i].m;
14761475

1477-
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
1476+
for (int j = 0; j < qk/2; ++j) {
1477+
const int x0 = (x[i].qs[j] & 0xf);
1478+
const int x1 = (x[i].qs[j] >> 4);
14781479

1479-
for (int l = 0; l < qk; ++l) {
1480-
y[i*qk + l] = qsp[l]*d + m;
1480+
y[i*qk + j + 0 ] = x0*d + m;
1481+
y[i*qk + j + qk/2] = x1*d + m;
14811482
}
14821483
}
14831484
}
14841485

14851486
static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) {
1487+
// BORKEN !!!
14861488
static const int qk = QK4_2;
14871489

14881490
assert(qk / 16 == 0);
@@ -1495,7 +1497,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
14951497
for (int i = 0; i < nb; i++) {
14961498
const float d = GGML_FP16_TO_FP32(x[i].d);
14971499

1498-
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
1500+
const uint8_t * qsp = b4_from_nibbles_64(qk, x[i].qs, qs);
14991501

15001502
for (int l = 0; l < qk; ++l) {
15011503
y[i*qk + l] = (qsp[l] - 8)*d;
@@ -1511,20 +1513,21 @@ static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict
15111513

15121514
const int nb = k / qk;
15131515

1514-
uint64_t qs[QK5_0 / 8];
1515-
15161516
for (int i = 0; i < nb; i++) {
15171517
const float d = GGML_FP16_TO_FP32(x[i].d);
15181518

15191519
uint32_t qh;
15201520
memcpy(&qh, x[i].qh, sizeof(qh));
15211521

1522-
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
1522+
for (int j = 0; j < qk/2; ++j) {
1523+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
1524+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
15231525

1524-
for (int l = 0; l < qk; ++l) {
1525-
const uint8_t vh = ((qh & (1u << l)) >> l) << 4;
1526+
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
1527+
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
15261528

1527-
y[i*qk + l] = ((qsp[l] | vh) - 16)*d;
1529+
y[i*qk + j + 0 ] = x0*d;
1530+
y[i*qk + j + qk/2] = x1*d;
15281531
}
15291532
}
15301533
}
@@ -2388,17 +2391,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23882391
// scalar
23892392
float sumf = 0.0;
23902393

2391-
uint64_t qs[QK8_0 / 8];
2392-
23932394
for (int i = 0; i < nb; i++) {
2394-
// unpack nibbles into bytes
2395-
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2396-
const int8_t * py = y[i].qs;
2395+
const int8_t * py = y[i].qs;
23972396

23982397
int sumi = 0;
23992398

2400-
for (int j = 0; j < qk; ++j) {
2401-
sumi += (px[j] - 8) * py[j];
2399+
for (int j = 0; j < qk/2; ++j) {
2400+
const int v0 = (x[i].qs[j] & 0xf) - 8;
2401+
const int v1 = (x[i].qs[j] >> 4) - 8;
2402+
2403+
sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
24022404
}
24032405

24042406
sumf += (x[i].d*y[i].d)*sumi;
@@ -2513,16 +2515,16 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
25132515
// scalar
25142516
float sumf = 0.0;
25152517

2516-
uint64_t qs[QK8_1 / 8];
2517-
25182518
for (int i = 0; i < nb; i++) {
2519-
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2520-
const int8_t * py = y[i].qs;
2519+
const int8_t * py = y[i].qs;
25212520

25222521
int sumi = 0;
25232522

2524-
for (int j = 0; j < qk; ++j) {
2525-
sumi += px[j]*py[j];
2523+
for (int j = 0; j < qk/2; ++j) {
2524+
const int v0 = (x[i].qs[j] & 0xf);
2525+
const int v1 = (x[i].qs[j] >> 4);
2526+
2527+
sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
25262528
}
25272529

25282530
sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
@@ -2847,22 +2849,22 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
28472849
// scalar
28482850
float sumf = 0.0;
28492851

2850-
uint64_t qs[QK8_0 / 8];
2851-
28522852
for (int i = 0; i < nb; i++) {
2853-
// unpack nibbles into bytes
2854-
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2855-
const int8_t * py = y[i].qs;
2853+
const int8_t * py = y[i].qs;
28562854

28572855
uint32_t qh;
28582856
memcpy(&qh, x[i].qh, sizeof(qh));
28592857

28602858
int sumi = 0;
28612859

2862-
for (int j = 0; j < qk; ++j) {
2863-
const int xh = ((qh & (1u << j)) >> j) << 4;
2860+
for (int j = 0; j < qk/2; ++j) {
2861+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
2862+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
2863+
2864+
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
2865+
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
28642866

2865-
sumi += ((px[j] | xh) - 16)*py[j];
2867+
sumi += (x0 * py[j]) + (x1 * py[j + qk/2]);
28662868
}
28672869

28682870
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;

0 commit comments

Comments
 (0)