Skip to content

Commit e5826a4

Browse files
committed
ggml : remove Q4_2 bit shuffling (WIP, BROKEN)
1 parent 34cdd78 commit e5826a4

File tree

1 file changed

+47
-70
lines changed

1 file changed

+47
-70
lines changed

ggml.c

Lines changed: 47 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
877877
static const int qk = QK4_0;
878878

879879
assert(qk / 16 == 0);
880-
assert(k % qk == 0);
880+
assert( k % qk == 0);
881881

882882
const int nb = k / qk;
883883

@@ -912,7 +912,7 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
912912
const int qk = QK4_1;
913913

914914
assert(qk / 16 == 0);
915-
assert(k % qk == 0);
915+
assert( k % qk == 0);
916916

917917
const int nb = k / qk;
918918

@@ -945,48 +945,37 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k
945945

946946
// reference implementation for deterministic creation of model files
947947
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
948-
assert(k % QK4_2 == 0);
948+
static const int qk = QK4_2;
949949

950-
const int nb = k / QK4_2;
950+
assert(qk / 16 == 0);
951+
assert( k % qk == 0);
952+
953+
const int nb = k / qk;
951954

952955
for (int i = 0; i < nb; i++) {
953956
float amax = 0.0f; // absolute max
954-
float max = 0.0f;
957+
float max = 0.0f;
955958

956-
for (int l = 0; l < QK4_2; l++) {
957-
const float v = x[i*QK4_2 + l];
959+
for (int l = 0; l < qk; l++) {
960+
const float v = x[i*qk + l];
958961
if (amax < fabsf(v)) {
959962
amax = fabsf(v);
960-
max = v;
963+
max = v;
961964
}
962965
}
963966

964-
const float d = max / -8;
965-
967+
const float d = max / -8;
966968
const float id = d ? 1.0f/d : 0.0f;
967969

968970
y[i].d = GGML_FP32_TO_FP16(d);
969971

970-
for (int l = 0; l < QK4_2; l += 2) {
971-
const float v0 = x[i*QK4_2 + l + 0]*id;
972-
const float v1 = x[i*QK4_2 + l + 1]*id;
973-
974-
const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f));
975-
const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
972+
uint64_t qs[QK4_2 / 16] = {0};
976973

977-
assert(vi0 < 16);
978-
assert(vi1 < 16);
979-
980-
y[i].qs[l/2] = vi0 | (vi1 << 4);
981-
}
974+
nibbles_from_floats_64_0(qk, x + i*qk, id, y[i].qs, qs);
982975
}
983976
}
984977

985-
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
986-
assert(k % QK4_2 == 0);
987-
988-
block_q4_2 * restrict y = vy;
989-
978+
static void quantize_row_q4_2(const float * restrict x, void * restrict y, int k) {
990979
quantize_row_q4_2_reference(x, y, k);
991980
}
992981

@@ -1324,7 +1313,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
13241313
static const int qk = QK4_0;
13251314

13261315
assert(qk / 16 == 0);
1327-
assert(k % qk == 0);
1316+
assert( k % qk == 0);
13281317

13291318
const int nb = k / qk;
13301319

@@ -1345,7 +1334,7 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
13451334
static const int qk = QK4_1;
13461335

13471336
assert(qk / 16 == 0);
1348-
assert(k % qk == 0);
1337+
assert( k % qk == 0);
13491338

13501339
const int nb = k / qk;
13511340

@@ -1363,31 +1352,23 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
13631352
}
13641353
}
13651354

1366-
static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
1367-
assert(k % QK4_2 == 0);
1368-
const int nb = k / QK4_2;
1369-
1370-
const block_q4_2 * restrict x = vx;
1371-
1372-
for (int i = 0; i < nb; i++) {
1373-
const float d = GGML_FP16_TO_FP32(x[i].d);
1355+
static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) {
1356+
static const int qk = QK4_2;
13741357

1375-
const uint8_t * restrict pp = x[i].qs;
1358+
assert(qk / 16 == 0);
1359+
assert( k % qk == 0);
13761360

1377-
for (int l = 0; l < QK4_2; l += 2) {
1378-
const uint8_t vi = pp[l/2];
1361+
const int nb = k / qk;
13791362

1380-
const int8_t vi0 = vi & 0x0F;
1381-
const int8_t vi1 = vi >> 4;
1363+
uint64_t qs[QK4_2 / 8];
13821364

1383-
const float v0 = (vi0 - 8)*d;
1384-
const float v1 = (vi1 - 8)*d;
1365+
for (int i = 0; i < nb; i++) {
1366+
const float d = GGML_FP16_TO_FP32(x[i].d);
13851367

1386-
y[i*QK4_2 + l + 0] = v0;
1387-
y[i*QK4_2 + l + 1] = v1;
1368+
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
13881369

1389-
assert(!isnan(y[i*QK4_2 + l + 0]));
1390-
assert(!isnan(y[i*QK4_2 + l + 1]));
1370+
for (int l = 0; l < qk; ++l) {
1371+
y[i*qk + l] = (qsp[l] - 8)*d;
13911372
}
13921373
}
13931374
}
@@ -1507,7 +1488,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
15071488
.vec_dot_type = GGML_TYPE_Q8_1,
15081489
},
15091490
[GGML_TYPE_Q4_2] = {
1510-
.dequantize_row_q = dequantize_row_q4_2,
1491+
.dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_2,
15111492
.quantize_row_q = quantize_row_q4_2,
15121493
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
15131494
.quantize_row_q_dot = quantize_row_q8_0,
@@ -2432,11 +2413,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
24322413
}
24332414

24342415
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
2435-
const int nb = n / QK8_0;
2416+
const int qk = QK8_0;
2417+
const int nb = n / qk;
24362418

2437-
assert(n % QK8_0 == 0);
2419+
assert(n % qk == 0);
24382420
assert(nb % 2 == 0);
2439-
assert(QK8_0 == 2*QK4_2);
2421+
2422+
assert(qk == 2*QK4_2);
24402423

24412424
const block_q4_2 * restrict x = vx;
24422425
const block_q8_0 * restrict y = vy;
@@ -2472,12 +2455,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
24722455
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
24732456
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
24742457

2475-
// interleave
2476-
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
2477-
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
2478-
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
2479-
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
2480-
24812458
// load y
24822459
const int8x16_t v1_0l = vld1q_s8(y0->qs);
24832460
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2486,22 +2463,22 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
24862463

24872464
#if defined(__ARM_FEATURE_DOTPROD)
24882465
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
2489-
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2490-
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
2466+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
2467+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hs, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
24912468

24922469
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
2493-
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2494-
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
2470+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
2471+
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hs, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
24952472
#else
2496-
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
2497-
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
2498-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
2499-
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
2500-
2501-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
2502-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
2503-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
2504-
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
2473+
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
2474+
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
2475+
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
2476+
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
2477+
2478+
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
2479+
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
2480+
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
2481+
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
25052482

25062483
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
25072484
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));

0 commit comments

Comments
 (0)