Skip to content

Commit c7af904

Browse files
committed
ggml : remove Q5_1 bit shuffling (ARM NEON + scalar)
1 parent 39bb8e7 commit c7af904

File tree

1 file changed

+66
-106
lines changed

1 file changed

+66
-106
lines changed

ggml.c

Lines changed: 66 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -851,8 +851,7 @@ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block s
851851
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
852852
static const int qk = QK4_0;
853853

854-
assert(qk / 16 == 0);
855-
assert( k % qk == 0);
854+
assert(k % qk == 0);
856855

857856
const int nb = k / qk;
858857

@@ -873,20 +872,16 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
873872

874873
y[i].d = d;
875874

876-
uint64_t qs[QK4_0 / 16] = {0};
877-
878875
for (int l = 0; l < qk/2; ++l) {
879876
const float x0 = x[i*qk + 0 + l]*id;
880877
const float x1 = x[i*qk + qk/2 + l]*id;
881878

882-
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
883-
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
879+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
880+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
884881

885-
qs[l/8] |= xi0 << (8*(l & 7));
886-
qs[l/8] |= xi1 << (8*(l & 7) + 4);
882+
y[i].qs[l] = xi0;
883+
y[i].qs[l] |= xi1 << 4;
887884
}
888-
889-
memcpy(y[i].qs, qs, qk/2);
890885
}
891886
}
892887

@@ -897,8 +892,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k
897892
static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
898893
const int qk = QK4_1;
899894

900-
assert(qk / 16 == 0);
901-
assert( k % qk == 0);
895+
assert(k % qk == 0);
902896

903897
const int nb = k / qk;
904898

@@ -919,20 +913,16 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
919913
y[i].d = d;
920914
y[i].m = min;
921915

922-
uint64_t qs[QK4_1 / 16] = {0};
923-
924916
for (int l = 0; l < qk/2; ++l) {
925917
const float x0 = (x[0 + l] - min)*id;
926918
const float x1 = (x[qk/2 + l] - min)*id;
927919

928-
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
929-
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
920+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
921+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
930922

931-
qs[l/8] |= xi0 << (8*(l & 7));
932-
qs[l/8] |= xi1 << (8*(l & 7) + 4);
923+
y[i].qs[l] = xi0;
924+
y[i].qs[l] |= xi1 << 4;
933925
}
934-
935-
memcpy(y[i].qs, qs, qk/2);
936926
}
937927
}
938928

@@ -944,8 +934,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k
944934
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
945935
static const int qk = QK4_2;
946936

947-
assert(qk / 16 == 0);
948-
assert( k % qk == 0);
937+
assert(k % qk == 0);
949938

950939
const int nb = k / qk;
951940

@@ -990,8 +979,7 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict y, int k
990979
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
991980
static const int qk = QK5_0;
992981

993-
assert(qk / 16 == 0);
994-
assert( k % qk == 0);
982+
assert(k % qk == 0);
995983

996984
const int nb = k / qk;
997985

@@ -1013,24 +1001,21 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
10131001
y[i].d = d;
10141002

10151003
uint32_t qh = 0;
1016-
uint64_t qs[QK5_0 / 16] = {0};
10171004

10181005
for (int l = 0; l < qk/2; ++l) {
10191006
const float x0 = x[i*qk + 0 + l]*id;
10201007
const float x1 = x[i*qk + qk/2 + l]*id;
10211008

1022-
const uint64_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
1023-
const uint64_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
1009+
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
1010+
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
10241011

1025-
qs[l/8] |= xi0 << (8*(l & 7));
1026-
qs[l/8] |= xi1 << (8*(l & 7) + 4);
1012+
y[i].qs[l] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
10271013

10281014
// get the 5-th bit and store it in qh at the right position
10291015
qh |= ((xi0 & 0x10) >> 4) << (l + 0);
10301016
qh |= ((xi1 & 0x10) >> 4) << (l + qk/2);
10311017
}
10321018

1033-
memcpy( y[i].qs, qs, qk/2);
10341019
memcpy(&y[i].qh, &qh, sizeof(qh));
10351020
}
10361021
}
@@ -1040,50 +1025,50 @@ static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k
10401025
}
10411026

10421027
static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
1043-
assert(k % QK5_1 == 0);
1044-
const int nb = k / QK5_1;
1028+
const int qk = QK5_1;
1029+
1030+
assert(k % qk == 0);
1031+
1032+
const int nb = k / qk;
10451033

10461034
for (int i = 0; i < nb; i++) {
10471035
float min = FLT_MAX;
10481036
float max = -FLT_MAX;
10491037

1050-
for (int l = 0; l < QK5_1; l++) {
1051-
const float v = x[i*QK5_1 + l];
1038+
for (int l = 0; l < qk; l++) {
1039+
const float v = x[i*qk + l];
1040+
10521041
if (v < min) min = v;
10531042
if (v > max) max = v;
10541043
}
10551044

1056-
const float d = (max - min) / ((1 << 5) - 1);
1045+
const float d = (max - min) / ((1 << 5) - 1);
10571046
const float id = d ? 1.0f/d : 0.0f;
10581047

10591048
y[i].d = GGML_FP32_TO_FP16(d);
10601049
y[i].m = GGML_FP32_TO_FP16(min);
10611050

10621051
uint32_t qh = 0;
10631052

1064-
for (int l = 0; l < QK5_1; l += 2) {
1065-
const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
1066-
const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
1053+
for (int l = 0; l < qk/2; ++l) {
1054+
const float x0 = (x[i*qk + 0 + l] - min)*id;
1055+
const float x1 = (x[i*qk + qk/2 + l] - min)*id;
10671056

1068-
const uint32_t vi0 = (int) (v0 + 0.5f);
1069-
const uint32_t vi1 = (int) (v1 + 0.5f);
1057+
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
1058+
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
10701059

1071-
y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
1060+
y[i].qs[l] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
10721061

10731062
// get the 5-th bit and store it in qh at the right position
1074-
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
1075-
qh |= ((vi1 & 0x10) >> 4) << (l + 1);
1063+
qh |= ((xi0 & 0x10) >> 4) << (l + 0);
1064+
qh |= ((xi1 & 0x10) >> 4) << (l + qk/2);
10761065
}
10771066

10781067
memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
10791068
}
10801069
}
10811070

1082-
static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
1083-
assert(k % QK5_1 == 0);
1084-
1085-
block_q5_1 * restrict y = vy;
1086-
1071+
static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) {
10871072
quantize_row_q5_1_reference(x, y, k);
10881073
}
10891074

@@ -1443,8 +1428,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
14431428
static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
14441429
static const int qk = QK4_0;
14451430

1446-
assert(qk / 16 == 0);
1447-
assert( k % qk == 0);
1431+
assert(k % qk == 0);
14481432

14491433
const int nb = k / qk;
14501434

@@ -1464,8 +1448,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
14641448
static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
14651449
static const int qk = QK4_1;
14661450

1467-
assert(qk / 16 == 0);
1468-
assert( k % qk == 0);
1451+
assert(k % qk == 0);
14691452

14701453
const int nb = k / qk;
14711454

@@ -1487,8 +1470,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
14871470
// BORKEN !!!
14881471
static const int qk = QK4_2;
14891472

1490-
assert(qk / 16 == 0);
1491-
assert( k % qk == 0);
1473+
assert(k % qk == 0);
14921474

14931475
const int nb = k / qk;
14941476

@@ -1508,8 +1490,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
15081490
static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) {
15091491
static const int qk = QK4_0;
15101492

1511-
assert(qk / 16 == 0);
1512-
assert( k % qk == 0);
1493+
assert(k % qk == 0);
15131494

15141495
const int nb = k / qk;
15151496

@@ -1532,39 +1513,29 @@ static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict
15321513
}
15331514
}
15341515

1535-
static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
1536-
assert(k % QK5_1 == 0);
1537-
const int nb = k / QK5_1;
1516+
static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) {
1517+
static const int qk = QK5_1;
15381518

1539-
const block_q5_1 * restrict x = vx;
1519+
assert(k % qk == 0);
1520+
1521+
const int nb = k / qk;
15401522

15411523
for (int i = 0; i < nb; i++) {
15421524
const float d = GGML_FP16_TO_FP32(x[i].d);
15431525
const float m = GGML_FP16_TO_FP32(x[i].m);
15441526

1545-
const uint8_t * restrict pp = x[i].qs;
1546-
15471527
uint32_t qh;
15481528
memcpy(&qh, x[i].qh, sizeof(qh));
15491529

1550-
for (int l = 0; l < QK5_1; l += 2) {
1551-
const uint8_t vi = pp[l/2];
1552-
1553-
// extract the 5-th bit from qh
1554-
const uint8_t vh0 = ((qh & (1u << (l + 0))) >> (l + 0)) << 4;
1555-
const uint8_t vh1 = ((qh & (1u << (l + 1))) >> (l + 1)) << 4;
1556-
1557-
const uint8_t vi0 = (vi & 0x0F) | vh0;
1558-
const uint8_t vi1 = (vi >> 4) | vh1;
1559-
1560-
const float v0 = vi0*d + m;
1561-
const float v1 = vi1*d + m;
1530+
for (int j = 0; j < qk/2; ++j) {
1531+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
1532+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
15621533

1563-
y[i*QK5_1 + l + 0] = v0;
1564-
y[i*QK5_1 + l + 1] = v1;
1534+
const int x0 = (x[i].qs[j] & 0xf) | xh_0;
1535+
const int x1 = (x[i].qs[j] >> 4) | xh_1;
15651536

1566-
assert(!isnan(y[i*QK5_1 + l + 0]));
1567-
assert(!isnan(y[i*QK5_1 + l + 1]));
1537+
y[i*qk + j + 0 ] = x0*d + m;
1538+
y[i*qk + j + qk/2] = x1*d + m;
15681539
}
15691540
}
15701541
}
@@ -1627,7 +1598,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
16271598
.vec_dot_type = GGML_TYPE_Q8_0,
16281599
},
16291600
[GGML_TYPE_Q5_1] = {
1630-
.dequantize_row_q = dequantize_row_q5_1,
1601+
.dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_1,
16311602
.quantize_row_q = quantize_row_q5_1,
16321603
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
16331604
.quantize_row_q_dot = quantize_row_q8_1,
@@ -2875,11 +2846,12 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
28752846
}
28762847

28772848
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2878-
const int nb = n / QK8_1;
2849+
const int qk = QK8_1;
2850+
const int nb = n / qk;
28792851

2880-
assert(n % QK8_1 == 0);
2852+
assert(n % qk == 0);
28812853
assert(nb % 2 == 0);
2882-
assert(QK8_1 == QK5_1);
2854+
assert(qk == QK5_1);
28832855

28842856
const block_q5_1 * restrict x = vx;
28852857
const block_q8_1 * restrict y = vy;
@@ -2915,13 +2887,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
29152887
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
29162888
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
29172889

2918-
// interleave
2919-
const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
2920-
const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
2921-
29222890
// add
2923-
const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
2924-
const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
2891+
const int8x16_t v0lf = vorrq_s8(v0l, qhl);
2892+
const int8x16_t v0hf = vorrq_s8(v0h, qhh);
29252893

29262894
// load y
29272895
const int8x16_t v1l = vld1q_s8(y0->qs);
@@ -3044,36 +3012,28 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
30443012

30453013
*s = hsum_float_8(acc) + summs;
30463014
#else
3015+
// scalar
30473016
float sumf = 0.0;
30483017

30493018
for (int i = 0; i < nb; i++) {
3050-
const uint8_t * restrict x0 = x[i].qs;
3051-
const int8_t * restrict y0 = y[i].qs;
3019+
const int8_t * py = y[i].qs;
30523020

30533021
uint32_t qh;
30543022
memcpy(&qh, x[i].qh, sizeof(qh));
30553023

3056-
const float d = GGML_FP16_TO_FP32(x[i].d);
3057-
const float m = GGML_FP16_TO_FP32(x[i].m);
3058-
3059-
int sxy = 0;
3060-
3061-
for (int j = 0; j < QK8_1/2; j++) {
3062-
const uint8_t v0 = x0[j];
3063-
3064-
const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
3065-
const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
3024+
int sumi = 0;
30663025

3067-
const int x0_0 = (v0 & 0x0F) | x0_0h;
3068-
const int x1_0 = (v0 >> 4) | x1_0h;
3026+
for (int j = 0; j < qk/2; ++j) {
3027+
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
3028+
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
30693029

3070-
const int y0_0 = y0[2*j + 0];
3071-
const int y1_0 = y0[2*j + 1];
3030+
const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
3031+
const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
30723032

3073-
sxy += x0_0*y0_0 + x1_0*y1_0;
3033+
sumi += (x0 * py[j]) + (x1 * py[j + qk/2]);
30743034
}
30753035

3076-
sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
3036+
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*(y[i].s0 + y[i].s1);
30773037
}
30783038

30793039
*s = sumf;

0 commit comments

Comments
 (0)