Skip to content

Commit 5425e06

Browse files
committed
ggml : alternative Q4_3 implementation using modified Q8_0
1 parent ec805ee commit 5425e06

File tree

1 file changed

+53
-34
lines changed

1 file changed

+53
-34
lines changed

ggml.c

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -656,10 +656,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
656656
#define QK8_0 32
657657
typedef struct {
658658
float d; // delta
659-
float s; // d * sum(qs[i])
659+
float s0; // d * sum(qs[i]) low
660+
float s1; // d * sum(qs[i]) high
660661
int8_t qs[QK8_0]; // quants
661662
} block_q8_0;
662-
static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
663+
static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
663664

664665

665666
// reference implementation for deterministic creation of model files
@@ -1299,13 +1300,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
12991300

13001301
y[i].d = d;
13011302

1302-
int sum = 0;
1303-
for (int l = 0; l < QK8_0; ++l) {
1304-
const float v = x[i*QK8_0 + l]*id;
1305-
y[i].qs[l] = roundf(v);
1306-
sum += y[i].qs[l];
1303+
int sum0 = 0;
1304+
int sum1 = 0;
1305+
1306+
for (int l = 0; l < QK8_0/2; ++l) {
1307+
const float v0 = x[i*QK8_0 + l]*id;
1308+
const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id;
1309+
1310+
y[i].qs[ l] = roundf(v0);
1311+
y[i].qs[QK8_0/2 + l] = roundf(v1);
1312+
1313+
sum0 += y[i].qs[ l];
1314+
sum1 += y[i].qs[QK8_0/2 + l];
13071315
}
1308-
y[i].s = d * sum;
1316+
1317+
y[i].s0 = d * sum0;
1318+
y[i].s1 = d * sum1;
13091319
}
13101320
}
13111321

@@ -1335,9 +1345,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13351345

13361346
y[i].d = d;
13371347

1338-
int32x4_t accv = vdupq_n_s32(0);
1348+
int32x4_t accv0 = vdupq_n_s32(0);
1349+
int32x4_t accv1 = vdupq_n_s32(0);
13391350

1340-
for (int l = 0; l < 8; l++) {
1351+
// low half
1352+
for (int l = 0; l < 4; l++) {
1353+
const float32x4_t v = vmulq_n_f32(srcv[l], id);
1354+
const int32x4_t vi = vcvtnq_s32_f32(v);
1355+
1356+
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
1357+
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
1358+
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
1359+
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
1360+
1361+
accv0 = vaddq_s32(accv0, vi);
1362+
}
1363+
1364+
// high half
1365+
for (int l = 4; l < 8; l++) {
13411366
const float32x4_t v = vmulq_n_f32(srcv[l], id);
13421367
const int32x4_t vi = vcvtnq_s32_f32(v);
13431368

@@ -1346,12 +1371,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
13461371
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
13471372
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
13481373

1349-
accv = vaddq_s32(accv, vi);
1374+
accv1 = vaddq_s32(accv1, vi);
13501375
}
1351-
int32_t sum = vaddvq_s32(accv);
1352-
y[i].s = d * sum;
1376+
1377+
const int32_t sum0 = vaddvq_s32(accv0);
1378+
const int32_t sum1 = vaddvq_s32(accv1);
1379+
1380+
y[i].s0 = d * sum0;
1381+
y[i].s1 = d * sum1;
13531382
}
13541383
#elif defined(__AVX2__) || defined(__AVX__)
1384+
// TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
13551385
for (int i = 0; i < nb; i++) {
13561386
// Load elements into 4 AVX vectors
13571387
__m256 v0 = _mm256_loadu_ps( x );
@@ -2395,7 +2425,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
23952425
const block_q8_0 * restrict y0 = &y[i + 0];
23962426
const block_q8_0 * restrict y1 = &y[i + 1];
23972427

2398-
sum8 += x0->d * y0->s + x1->d * y1->s;
2428+
sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1);
23992429

24002430
const uint8x16_t m4b = vdupq_n_u8(0xf);
24012431

@@ -2562,7 +2592,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25622592
const block_q8_0 * restrict y0 = &y[i + 0];
25632593
const block_q8_0 * restrict y1 = &y[i + 1];
25642594

2565-
summs += x0->m * y0->s + x1->m * y1->s;
2595+
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
25662596

25672597
const uint8x16_t m4b = vdupq_n_u8(0xf);
25682598

@@ -2589,8 +2619,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25892619

25902620
#if defined(__ARM_FEATURE_DOTPROD)
25912621
// dot product into int32x4_t
2592-
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
2593-
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
2622+
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
2623+
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
25942624

25952625
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
25962626
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
@@ -2845,6 +2875,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
28452875
float32x4_t sumv0 = vdupq_n_f32(0.0f);
28462876
float32x4_t sumv1 = vdupq_n_f32(0.0f);
28472877

2878+
float summs = 0.0f;
2879+
28482880
for (int i = 0; i < nb; i += 2) {
28492881
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
28502882
const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
@@ -2854,18 +2886,16 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
28542886
const block_q8_0 * restrict y0 = &y[i + 0];
28552887
const block_q8_0 * restrict y1 = &y[i + 1];
28562888

2889+
summs += GGML_FP16_TO_FP32(x0_0->m) * y0->s0 + GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
2890+
summs += GGML_FP16_TO_FP32(x1_0->m) * y1->s0 + GGML_FP16_TO_FP32(x1_1->m) * y1->s1;
2891+
28572892
const uint8x16_t m4b = vdupq_n_u8(0xf);
28582893

28592894
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
28602895
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
28612896
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
28622897
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
28632898

2864-
const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
2865-
const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
2866-
const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
2867-
const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
2868-
28692899
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
28702900
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
28712901

@@ -2887,17 +2917,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
28872917
const int8x16_t v1_1l = vld1q_s8(y1->qs);
28882918
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
28892919

2890-
const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
2891-
const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
2892-
2893-
const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
2894-
const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
2895-
2896-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
2897-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
2898-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
2899-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
2900-
29012920
#if defined(__ARM_FEATURE_DOTPROD)
29022921
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
29032922
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
@@ -2926,7 +2945,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29262945
#endif
29272946
}
29282947

2929-
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2948+
sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs;
29302949
#elif defined(__AVX2__)
29312950
// Initialize accumulator with zeros
29322951
__m256 acc = _mm256_setzero_ps();

0 commit comments

Comments
 (0)