Skip to content

Commit b639b45

Browse files
committed
ggml : 2x faster scalar implementations
1 parent 8dbd7e7 commit b639b45

File tree

1 file changed

+67
-65
lines changed

1 file changed

+67
-65
lines changed

ggml.c

Lines changed: 67 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,8 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
608608

609609
#if __ARM_NEON
610610

611-
static inline const uint8_t * bytes_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
611+
// TODO: obosolete - will be removed
612+
static inline const uint8_t * b4_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
612613
memcpy(qd, qs, qk/2);
613614

614615
for (int l = 0; l < qk/16; ++l) {
@@ -868,14 +869,14 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
868869
uint64_t qs[QK4_0 / 16] = {0};
869870

870871
for (int l = 0; l < qk/2; ++l) {
871-
const float v0 = x[i*qk + 0 + l]*id;
872-
const float v1 = x[i*qk + qk/2 + l]*id;
872+
const float x0 = x[i*qk + 0 + l]*id;
873+
const float x1 = x[i*qk + qk/2 + l]*id;
873874

874-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
875-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
875+
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
876+
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
876877

877-
qs[l/8] |= vi0 << (8*(l & 7));
878-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
878+
qs[l/8] |= xi0 << (8*(l & 7));
879+
qs[l/8] |= xi1 << (8*(l & 7) + 4);
879880
}
880881

881882
memcpy(y[i].qs, qs, qk/2);
@@ -914,14 +915,14 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
914915
uint64_t qs[QK4_1 / 16] = {0};
915916

916917
for (int l = 0; l < qk/2; ++l) {
917-
const float v0 = (x[0 + l] - min)*id;
918-
const float v1 = (x[qk/2 + l] - min)*id;
918+
const float x0 = (x[0 + l] - min)*id;
919+
const float x1 = (x[qk/2 + l] - min)*id;
919920

920-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
921-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
921+
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
922+
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
922923

923-
qs[l/8] |= vi0 << (8*(l & 7));
924-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
924+
qs[l/8] |= xi0 << (8*(l & 7));
925+
qs[l/8] |= xi1 << (8*(l & 7) + 4);
925926
}
926927

927928
memcpy(y[i].qs, qs, qk/2);
@@ -961,14 +962,14 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
961962
uint64_t qs[QK4_2 / 16] = {0};
962963

963964
for (int l = 0; l < qk/2; ++l) {
964-
const float v0 = x[i*qk + 0 + l]*id;
965-
const float v1 = x[i*qk + qk/2 + l]*id;
965+
const float x0 = x[i*qk + 0 + l]*id;
966+
const float x1 = x[i*qk + qk/2 + l]*id;
966967

967-
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
968-
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
968+
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
969+
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
969970

970-
qs[l/8] |= vi0 << (8*(l & 7));
971-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
971+
qs[l/8] |= xi0 << (8*(l & 7));
972+
qs[l/8] |= xi1 << (8*(l & 7) + 4);
972973
}
973974

974975
memcpy(y[i].qs, qs, qk/2);
@@ -1008,18 +1009,18 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
10081009
uint64_t qs[QK5_0 / 16] = {0};
10091010

10101011
for (int l = 0; l < qk/2; ++l) {
1011-
const float v0 = x[i*qk + 0 + l]*id;
1012-
const float v1 = x[i*qk + qk/2 + l]*id;
1012+
const float x0 = x[i*qk + 0 + l]*id;
1013+
const float x1 = x[i*qk + qk/2 + l]*id;
10131014

1014-
const uint64_t vi0 = MIN(31, (int8_t)(v0 + 16.5f));
1015-
const uint64_t vi1 = MIN(31, (int8_t)(v1 + 16.5f));
1015+
const uint64_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
1016+
const uint64_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
10161017

1017-
qs[l/8] |= vi0 << (8*(l & 7));
1018-
qs[l/8] |= vi1 << (8*(l & 7) + 4);
1018+
qs[l/8] |= xi0 << (8*(l & 7));
1019+
qs[l/8] |= xi1 << (8*(l & 7) + 4);
10191020

10201021
// get the 5-th bit and store it in qh at the right position
1021-
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1022-
qh |= ((vi1 & 0x10) >> 4) << (l + qk/2);
1022+
qh |= ((xi0 & 0x10) >> 4) << (l + 0);
1023+
qh |= ((xi1 & 0x10) >> 4) << (l + qk/2);
10231024
}
10241025

10251026
memcpy( y[i].qs, qs, qk/2);
@@ -1320,15 +1321,15 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
13201321

13211322
const int nb = k / qk;
13221323

1323-
uint64_t qs[QK4_0 / 8];
1324-
13251324
for (int i = 0; i < nb; i++) {
13261325
const float d = x[i].d;
13271326

1328-
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
1327+
for (int j = 0; j < qk/2; ++j) {
1328+
const int x0 = (x[i].qs[j] & 0xf) - 8;
1329+
const int x1 = (x[i].qs[j] >> 4) - 8;
13291330

1330-
for (int l = 0; l < qk; ++l) {
1331-
y[i*qk + l] = (qsp[l] - 8)*d;
1331+
y[i*qk + j + 0 ] = x0*d;
1332+
y[i*qk + j + qk/2] = x1*d;
13321333
}
13331334
}
13341335
}
@@ -1341,21 +1342,22 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
13411342

13421343
const int nb = k / qk;
13431344

1344-
uint64_t qs[QK4_0 / 8];
1345-
13461345
for (int i = 0; i < nb; i++) {
13471346
const float d = x[i].d;
13481347
const float m = x[i].m;
13491348

1350-
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
1349+
for (int j = 0; j < qk/2; ++j) {
1350+
const int x0 = (x[i].qs[j] & 0xf);
1351+
const int x1 = (x[i].qs[j] >> 4);
13511352

1352-
for (int l = 0; l < qk; ++l) {
1353-
y[i*qk + l] = qsp[l]*d + m;
1353+
y[i*qk + j + 0 ] = x0*d + m;
1354+
y[i*qk + j + qk/2] = x1*d + m;
13541355
}
13551356
}
13561357
}
13571358

13581359
static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) {
1360+
// BORKEN !!!
13591361
static const int qk = QK4_2;
13601362

13611363
assert(qk / 16 == 0);
@@ -1368,7 +1370,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
13681370
for (int i = 0; i < nb; i++) {
13691371
const float d = GGML_FP16_TO_FP32(x[i].d);
13701372

1371-
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
1373+
const uint8_t * qsp = b4_from_nibbles_64(qk, x[i].qs, qs);
13721374

13731375
for (int l = 0; l < qk; ++l) {
13741376
y[i*qk + l] = (qsp[l] - 8)*d;
@@ -1384,20 +1386,21 @@ static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict
13841386

13851387
const int nb = k / qk;
13861388

1387-
uint64_t qs[QK5_0 / 8];
1388-
13891389
for (int i = 0; i < nb; i++) {
13901390
const float d = GGML_FP16_TO_FP32(x[i].d);
13911391

13921392
uint32_t qh;
13931393
memcpy(&qh, x[i].qh, sizeof(qh));
13941394

1395-
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
1395+
for (int j = 0; j < qk/2; ++j) {
1396+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
1397+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
13961398

1397-
for (int l = 0; l < qk; ++l) {
1398-
const uint8_t vh = ((qh & (1u << l)) >> l) << 4;
1399+
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
1400+
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
13991401

1400-
y[i*qk + l] = ((qsp[l] | vh) - 16)*d;
1402+
y[i*qk + j + 0 ] = x0*d;
1403+
y[i*qk + j + qk/2] = x1*d;
14011404
}
14021405
}
14031406
}
@@ -2261,17 +2264,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
22612264
// scalar
22622265
float sumf = 0.0;
22632266

2264-
uint64_t qs[QK8_0 / 8];
2265-
22662267
for (int i = 0; i < nb; i++) {
2267-
// unpack nibbles into bytes
2268-
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2269-
const int8_t * py = y[i].qs;
2268+
const int8_t * py = y[i].qs;
22702269

22712270
int sumi = 0;
22722271

2273-
for (int j = 0; j < qk; ++j) {
2274-
sumi += (px[j] - 8) * py[j];
2272+
for (int j = 0; j < qk/2; ++j) {
2273+
const int v0 = (x[i].qs[j] & 0xf) - 8;
2274+
const int v1 = (x[i].qs[j] >> 4) - 8;
2275+
2276+
sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
22752277
}
22762278

22772279
sumf += (x[i].d*y[i].d)*sumi;
@@ -2386,16 +2388,16 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
23862388
// scalar
23872389
float sumf = 0.0;
23882390

2389-
uint64_t qs[QK8_1 / 8];
2390-
23912391
for (int i = 0; i < nb; i++) {
2392-
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2393-
const int8_t * py = y[i].qs;
2392+
const int8_t * py = y[i].qs;
23942393

23952394
int sumi = 0;
23962395

2397-
for (int j = 0; j < qk; ++j) {
2398-
sumi += px[j]*py[j];
2396+
for (int j = 0; j < qk/2; ++j) {
2397+
const int v0 = (x[i].qs[j] & 0xf);
2398+
const int v1 = (x[i].qs[j] >> 4);
2399+
2400+
sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
23992401
}
24002402

24012403
sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
@@ -2720,22 +2722,22 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
27202722
// scalar
27212723
float sumf = 0.0;
27222724

2723-
uint64_t qs[QK8_0 / 8];
2724-
27252725
for (int i = 0; i < nb; i++) {
2726-
// unpack nibbles into bytes
2727-
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
2728-
const int8_t * py = y[i].qs;
2726+
const int8_t * py = y[i].qs;
27292727

27302728
uint32_t qh;
27312729
memcpy(&qh, x[i].qh, sizeof(qh));
27322730

27332731
int sumi = 0;
27342732

2735-
for (int j = 0; j < qk; ++j) {
2736-
const int xh = ((qh & (1u << j)) >> j) << 4;
2733+
for (int j = 0; j < qk/2; ++j) {
2734+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
2735+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
2736+
2737+
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
2738+
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
27372739

2738-
sumi += ((px[j] | xh) - 16)*py[j];
2740+
sumi += (x0 * py[j]) + (x1 * py[j + qk/2]);
27392741
}
27402742

27412743
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;

0 commit comments

Comments
 (0)