Skip to content

Commit 96e5d02

Browse files
committed
ggml : simplify Q8_1 - no need for low / high sums anymore
1 parent f9bbbe3 commit 96e5d02

File tree

1 file changed

+22
-47
lines changed

1 file changed

+22
-47
lines changed

ggml.c

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -718,12 +718,11 @@ static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block siz
718718

719719
#define QK8_1 32
720720
typedef struct {
721-
float d; // delta
722-
float s0; // d * sum(qs[i]) low
723-
float s1; // d * sum(qs[i]) high
724-
int8_t qs[QK8_1]; // quants
721+
float d; // delta
722+
float s; // d * sum(qs[i])
723+
int8_t qs[QK8_1]; // quants
725724
} block_q8_1;
726-
static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
725+
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
727726

728727
// reference implementation for deterministic creation of model files
729728
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
@@ -1078,8 +1077,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
10781077

10791078
y[i].d = d;
10801079

1081-
int sum0 = 0;
1082-
int sum1 = 0;
1080+
int sum = 0;
10831081

10841082
for (int j = 0; j < QK8_1/2; ++j) {
10851083
const float v0 = x[i*QK8_1 + 2*j + 0]*id;
@@ -1088,12 +1086,11 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
10881086
y[i].qs[ j] = v0 + 0.5f;
10891087
y[i].qs[QK8_1/2 + j] = v1 + 0.5f;
10901088

1091-
sum0 += y[i].qs[ j];
1092-
sum1 += y[i].qs[QK8_1/2 + j];
1089+
sum += y[i].qs[ j];
1090+
sum += y[i].qs[QK8_1/2 + j];
10931091
}
10941092

1095-
y[i].s0 = d * sum0;
1096-
y[i].s1 = d * sum1;
1093+
y[i].s = d * sum;
10971094
}
10981095
}
10991096

@@ -1123,24 +1120,9 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
11231120

11241121
y[i].d = d;
11251122

1126-
int32x4_t accv0 = vdupq_n_s32(0);
1127-
int32x4_t accv1 = vdupq_n_s32(0);
1128-
1129-
// low half
1130-
for (int j = 0; j < 4; j++) {
1131-
const float32x4_t v = vmulq_n_f32(srcv[j], id);
1132-
const int32x4_t vi = vcvtnq_s32_f32(v);
1123+
int32x4_t accv = vdupq_n_s32(0);
11331124

1134-
y[i].qs[ 2*j + 0] = vgetq_lane_s32(vi, 0);
1135-
y[i].qs[16 + 2*j + 0] = vgetq_lane_s32(vi, 1);
1136-
y[i].qs[ 2*j + 1] = vgetq_lane_s32(vi, 2);
1137-
y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
1138-
1139-
accv0 = vaddq_s32(accv0, vi);
1140-
}
1141-
1142-
// high half
1143-
for (int j = 4; j < 8; j++) {
1125+
for (int j = 0; j < 8; j++) {
11441126
const float32x4_t v = vmulq_n_f32(srcv[j], id);
11451127
const int32x4_t vi = vcvtnq_s32_f32(v);
11461128

@@ -1149,14 +1131,10 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
11491131
y[i].qs[ 2*j + 1] = vgetq_lane_s32(vi, 2);
11501132
y[i].qs[16 + 2*j + 1] = vgetq_lane_s32(vi, 3);
11511133

1152-
accv1 = vaddq_s32(accv1, vi);
1134+
accv = vaddq_s32(accv, vi);
11531135
}
11541136

1155-
const int32_t sum0 = vaddvq_s32(accv0);
1156-
const int32_t sum1 = vaddvq_s32(accv1);
1157-
1158-
y[i].s0 = d * sum0;
1159-
y[i].s1 = d * sum1;
1137+
y[i].s = d * vaddvq_s32(accv);
11601138
}
11611139
#elif defined(__AVX2__) || defined(__AVX__)
11621140
for (int i = 0; i < nb; i++) {
@@ -1205,9 +1183,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
12051183

12061184
#if defined(__AVX2__)
12071185
// Compute the sum of the quants and set y[i].s
1208-
//y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
1209-
y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
1210-
y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
1186+
y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
12111187

12121188
// Convert int32 to int16
12131189
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
@@ -1237,8 +1213,7 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
12371213
// Compute the sum of the quants and set y[i].s
12381214
const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
12391215
const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
1240-
y[i].s0 = d * hsum_i32_4(s0);
1241-
y[i].s1 = d * hsum_i32_4(s1);
1216+
y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1));
12421217

12431218
// Convert int32 to int16
12441219
ni0 = _mm_packs_epi32( ni0, ni1 );
@@ -2200,7 +2175,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
22002175
const block_q8_1 * restrict y0 = &y[i + 0];
22012176
const block_q8_1 * restrict y1 = &y[i + 1];
22022177

2203-
summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
2178+
summs += x0->m * y0->s + x1->m * y1->s;
22042179

22052180
const uint8x16_t m4b = vdupq_n_u8(0x0F);
22062181

@@ -2259,7 +2234,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
22592234
const float * d0 = &x[i].d;
22602235
const float * d1 = &y[i].d;
22612236

2262-
summs += x[i].m * (y[i].s0 + y[i].s1);
2237+
summs += x[i].m * y[i].s;
22632238

22642239
const __m256 d0v = _mm256_broadcast_ss( d0 );
22652240
const __m256 d1v = _mm256_broadcast_ss( d1 );
@@ -2292,7 +2267,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
22922267
sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
22932268
}
22942269

2295-
sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
2270+
sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s;
22962271
}
22972272

22982273
*s = sumf;
@@ -2545,8 +2520,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
25452520

25462521
const uint8x16_t m4b = vdupq_n_u8(0x0F);
25472522

2548-
summs0 += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
2549-
summs1 += GGML_FP16_TO_FP32(x1->m) * (y1->s0 + y1->s1);
2523+
summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s;
2524+
summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s;
25502525

25512526
// extract the 5th bit via lookup table ((b) << 4)
25522527
memcpy(&qh0, x0->qh, sizeof(qh0));
@@ -2632,7 +2607,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
26322607
const block_q5_1 * restrict x0 = &x[i];
26332608
const block_q8_1 * restrict y0 = &y[i];
26342609

2635-
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
2610+
summs += GGML_FP16_TO_FP32(x0->m) * y0->s;
26362611

26372612
const v128_t m4b = wasm_i8x16_splat(0x0F);
26382613

@@ -2696,7 +2671,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
26962671
for (int i = 0; i < nb; i++) {
26972672
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
26982673

2699-
summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
2674+
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
27002675

27012676
__m256i bx = bytes_from_nibbles_32(x[i].qs);
27022677
__m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2732,7 +2707,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
27322707
sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
27332708
}
27342709

2735-
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*(y[i].s0 + y[i].s1);
2710+
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
27362711
}
27372712

27382713
*s = sumf;

0 commit comments

Comments
 (0)