Commit 5fefd47

ggml-quants : ARM NEON vec_dot for q2_2 and q1_3

1 parent 5a64b71 commit 5fefd47
File tree

2 files changed: +182 -18 lines changed

ggml-quants.c

Lines changed: 170 additions & 18 deletions
@@ -686,6 +686,10 @@ void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict
     }
 }

+void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q2_2_reference(x, y, k);
+}
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
     static const int qk = QK4_0;
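For context, a minimal scalar round-trip for one q2_2 block can be inferred from the fallback loop later in this file: QK2_2 is assumed to be 32, each qs byte packs four 2-bit weights at bit positions 0, 2, 4 and 6, stored values map to {-2..1} after subtracting 2, and the block itself carries no scale. A hedged sketch, not code from this commit:

static void dequantize_block_q2_2_sketch(const block_q2_2 * x, float d, float * out) {
    for (int j = 0; j < QK2_2 / 4; ++j) {                    // 8 bytes per block
        for (int s = 0; s < 4; ++s) {                        // bit pairs 0, 2, 4, 6
            const int w = ((x->qs[j] >> (2 * s)) & 3) - 2;   // {0..3} -> {-2..1}
            out[j + s * (QK2_2 / 4)] = d * (float) w;        // d is supplied externally; the block stores no scale
        }
    }
}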
@@ -3900,17 +3904,81 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r
     }

     *s = hsum_float_8(acc);
+#elif defined(__ARM_NEON)
+    float sumf0 = 0.0f;
+    float sumf1 = 0.0f;
+
+    const uint8x8_t mask = vdup_n_u8(3);
+    const int8x8_t offset = vdup_n_s8(2);
+
+    const int leftovers = nb % 2;
+
+    for (int i = 0; i < nb - leftovers; i += 2) {
+        const uint8x8_t xq8_0 = vld1_u8(x[0].qs);
+        const uint8x8_t xq8_1 = vld1_u8(x[1].qs);
+
+        const int8x8_t xq8_0_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_0, mask)), offset);
+        const int8x8_t xq8_0_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 2), mask)), offset);
+        const int8x8_t xq8_0_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 4), mask)), offset);
+        const int8x8_t xq8_0_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 6), mask)), offset);
+        const int8x8_t xq8_1_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_1, mask)), offset);
+        const int8x8_t xq8_1_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 2), mask)), offset);
+        const int8x8_t xq8_1_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 4), mask)), offset);
+        const int8x8_t xq8_1_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 6), mask)), offset);
+
+        const int8x16_t xq8_0_l = vcombine_s8(xq8_0_0, xq8_0_1);
+        const int8x16_t xq8_0_h = vcombine_s8(xq8_0_2, xq8_0_3);
+        const int8x16_t xq8_1_l = vcombine_s8(xq8_1_0, xq8_1_1);
+        const int8x16_t xq8_1_h = vcombine_s8(xq8_1_2, xq8_1_3);
+
+        const int8x16_t yq8_0_l = vld1q_s8(y[0].qs + 0);
+        const int8x16_t yq8_0_h = vld1q_s8(y[0].qs + 16);
+        const int8x16_t yq8_1_l = vld1q_s8(y[1].qs + 0);
+        const int8x16_t yq8_1_h = vld1q_s8(y[1].qs + 16);
+
+        const int16x8_t dot0 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_0_l, yq8_0_l)), vpaddlq_s8(vmulq_s8(xq8_0_h, yq8_0_h)));
+        const int16x8_t dot1 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_1_l, yq8_1_l)), vpaddlq_s8(vmulq_s8(xq8_1_h, yq8_1_h)));
+
+        sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(dot0);
+        sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(dot1);
+        x += 2;
+        y += 2;
+    }
+
+    // one block at a time
+    for (int i = nb - leftovers; i < nb; ++i) {
+        const uint8x8_t xq8 = vld1_u8(x->qs);
+        const int8x8_t xq8_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8, mask)), offset);
+        const int8x8_t xq8_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 2), mask)), offset);
+        const int8x8_t xq8_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 4), mask)), offset);
+        const int8x8_t xq8_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 6), mask)), offset);
+
+        const int8x16_t xq8_l = vcombine_s8(xq8_0, xq8_1);
+        const int8x16_t xq8_h = vcombine_s8(xq8_2, xq8_3);
+
+        const int8x16_t yq8_l = vld1q_s8(y->qs + 0);
+        const int8x16_t yq8_h = vld1q_s8(y->qs + 16);
+
+        const int16x8_t dot0 = vpaddlq_s8(vmulq_s8(xq8_l, yq8_l));
+        const int16x8_t dot1 = vpaddlq_s8(vmulq_s8(xq8_h, yq8_h));
+
+        sumf0 += GGML_FP16_TO_FP32(y->d) * (float) vaddlvq_s16(vaddq_s16(dot0, dot1));
+        x += 1;
+        y += 1;
+    }
+
+    *s = sumf0 + sumf1;
 #else

-    float sumf = 0.0;
+    float sumf = 0.0f;
     for (int i = 0; i < nb; i++) {
         int sumi = 0;
         for (int j = 0; j < qk / 4; j++) {
             const uint8_t weight = x[i].qs[j];
-            sumi += (int)y[i].qs[j + 0*qk/4] * ((weight >> 0) & 3) - 2;
-            sumi += (int)y[i].qs[j + 1*qk/4] * ((weight >> 2) & 3) - 2;
-            sumi += (int)y[i].qs[j + 2*qk/4] * ((weight >> 4) & 3) - 2;
-            sumi += (int)y[i].qs[j + 3*qk/4] * ((weight >> 6) & 3) - 2;
+            sumi += (int)y[i].qs[j + 0*qk/4] * (((weight >> 0) & 3) - 2);
+            sumi += (int)y[i].qs[j + 1*qk/4] * (((weight >> 2) & 3) - 2);
+            sumi += (int)y[i].qs[j + 2*qk/4] * (((weight >> 4) & 3) - 2);
+            sumi += (int)y[i].qs[j + 3*qk/4] * (((weight >> 6) & 3) - 2);
         }
         sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
     }
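The NEON path reduces the 32 products per block by widening twice: vpaddlq_s8 sums adjacent int8 lanes into int16, vaddq_s16 merges the low and high halves, and vaddlvq_s16 folds the eight int16 lanes into a single int32. A hedged scalar model of what dot0 amounts to for one block, using the same de-interleaved weight layout as the fallback loop (the model accumulates products in 32 bits, whereas vmulq_s8 keeps them in int8 lanes):

static int32_t q2_2_block_dot_sketch(const uint8_t * xqs, const int8_t * yqs) {
    int32_t sum = 0;
    for (int j = 0; j < 8; ++j) {                            // 8 bytes, 4 two-bit weights each
        for (int s = 0; s < 4; ++s) {                        // shift amounts 0, 2, 4, 6
            const int w = ((xqs[j] >> (2 * s)) & 3) - 2;
            sum += w * yqs[8 * s + j];                       // lane order built by the vcombine_s8 calls
        }
    }
    return sum;
}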
@@ -11314,27 +11382,27 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r

     for (int i = 0; i < nb; ++i) {
         // const __m128i x12a = _mm_maskload_epi32((const int32_t *) x, _mm_set_epi32(0, -1, -1, -1));
-        // const __m128i x12b = _mm_insert_epi8(x12a, x->qs[0], 12);
+        // const __m128i x13b = _mm_insert_epi8(x12a, x->qs[0], 12);
         // WARNING: reading 3 bytes further than necessary.
         // It's measurably faster than a masked load on an Intel Core m3-8100Y
-        const __m128i x12b = _mm_loadu_si128((const __m128i_u *) x);
-        const __m256i x12 = MM256_SET_M128I(x12b, x12b);
+        const __m128i x13b = _mm_loadu_si128((const __m128i_u *) x);
+        const __m256i x13 = MM256_SET_M128I(x13b, x13b);

         {
             // pre-shift the values by 8 bits, and prepare the layout for later packing
-            __m256i x0l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
+            __m256i x0l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
                                                                    4, -1, 4, -1, 4, -1, 4, -1,
                                                                    1, -1, 1, -1, 1, -1, 1, -1,
                                                                    0, -1, 0, -1, 0, -1, 0, -1));
-            __m256i x0h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
+            __m256i x0h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
                                                                    6, -1, 6, -1, 6, -1, 6, -1,
                                                                    3, -1, 3, -1, 3, -1, 3, -1,
                                                                    2, -1, 2, -1, 2, -1, 2, -1));
-            __m256i x1l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1,
+            __m256i x1l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1,
                                                                    3, -1, 2, -1, 1, -1, 0, -1,
                                                                    9, -1, 9, -1, 9, -1, 9, -1,
                                                                    8, -1, 8, -1, 8, -1, 8, -1));
-            __m256i x1h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1,
+            __m256i x1h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1,
                                                                    11, -1, 10, -1, 9, -1, 8, -1,
                                                                    11, -1, 11, -1, 11, -1, 11, -1,
                                                                    10, -1, 10, -1, 10, -1, 10, -1));
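The rename from x12 to x13 reflects how many of the loaded bytes are actually meaningful: 12 q bytes plus the first qs byte, which is why the unaligned 16-byte load reads 3 bytes past the block. A hypothetical mirror of the layout, consistent with the scalar fallback at the end of this function (the real definition lives in the block headers, not in this diff):

typedef struct {
    uint8_t q[12];   // 12 bytes x 5 ternary weights each = 60 weights
    uint8_t qs[1];   // 4 more ternary weights, for QK1_3 = 64 in total
} block_q1_3_sketch; // 13 meaningful bytes per block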
@@ -11385,6 +11453,88 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
     }

     *s = hsum_float_8(accumf);
+#elif defined(__ARM_NEON)
+
+    static const uint8_t k_mask0[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
+    static const uint8_t k_mask1[16] = {4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7};
+    static const uint8_t k_mask2[16] = {8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11};
+    static const uint8_t k_mask3[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, 12};
+
+    static const uint8_t k_shift0[16] = {81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3};
+    static const uint8_t k_shift3[16] = { 1,  1, 1, 1,  1,  1, 1, 1,  1,  1, 1, 1, 81, 27, 9, 3};
+
+    // float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    // float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float sumf0 = 0.0f;
+    float sumf1 = 0.0f;
+
+    const uint8x16_t mask0 = vld1q_u8(k_mask0);
+    const uint8x16_t mask1 = vld1q_u8(k_mask1);
+    const uint8x16_t mask2 = vld1q_u8(k_mask2);
+    const uint8x16_t mask3 = vld1q_u8(k_mask3);
+
+    const uint8x16_t shift0 = vld1q_u8(k_shift0);
+    const uint8x16_t shift3 = vld1q_u8(k_shift3);
+
+    const int8x16_t one = vdupq_n_s8(1);
+
+    for (int i = 0; i < nb; ++i) {
+        // WARNING: reading 3 bytes further than necessary
+        const uint8x16_t x13b = vld1q_u8((const uint8_t *) x);
+
+        uint8x16_t x0 = vqtbl1q_u8(x13b, mask0);
+        uint8x16_t x1 = vqtbl1q_u8(x13b, mask1);
+        uint8x16_t x2 = vqtbl1q_u8(x13b, mask2);
+        uint8x16_t x3 = vqtbl1q_u8(x13b, mask3);
+
+        x0 = vmulq_u8(x0, shift0);
+        x1 = vmulq_u8(x1, shift0);
+        x2 = vmulq_u8(x2, shift0);
+        x3 = vmulq_u8(x3, shift3);
+
+        // multiply by 3 and keep the 2 bits above 8 bits
+        x0 = vshrq_n_u8(vhaddq_u8(x0, vshrq_n_u8(x0, 1)), 6);
+        x1 = vshrq_n_u8(vhaddq_u8(x1, vshrq_n_u8(x1, 1)), 6);
+        x2 = vshrq_n_u8(vhaddq_u8(x2, vshrq_n_u8(x2, 1)), 6);
+        x3 = vshrq_n_u8(vhaddq_u8(x3, vshrq_n_u8(x3, 1)), 6);
+
+        // 0, 1, 2 => -1, 0, 1
+        int8x16_t x0i = vsubq_s8(vreinterpretq_s8_u8(x0), one);
+        int8x16_t x1i = vsubq_s8(vreinterpretq_s8_u8(x1), one);
+        int8x16_t x2i = vsubq_s8(vreinterpretq_s8_u8(x2), one);
+        int8x16_t x3i = vsubq_s8(vreinterpretq_s8_u8(x3), one);
+
+        const int8x16_t y0 = vld1q_s8(y[0].qs + 0);
+        const int8x16_t y1 = vld1q_s8(y[0].qs + 16);
+        const int8x16_t y2 = vld1q_s8(y[1].qs + 0);
+        const int8x16_t y3 = vld1q_s8(y[1].qs + 16);
+
+        // const int32x4_t p0 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i)));
+        // const int32x4_t p1 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i)));
+
+        // there's no direct equivalent to _mm_sign_epi8, unfortunately
+        x0i = vmulq_s8(x0i, y0);
+        x1i = vmulq_s8(x1i, y1);
+        x2i = vmulq_s8(x2i, y2);
+        x3i = vmulq_s8(x3i, y3);
+
+        // overall 18.5% faster than with vector sums on a cortex-A72
+        sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i)));
+        sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i)));
+
+        // const int32x4_t p0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x0i, y0), x1i, y1);
+        // const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x2i, y2), x3i, y3);
+
+        // sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p0), GGML_FP16_TO_FP32(y[0].d));
+        // sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p1), GGML_FP16_TO_FP32(y[1].d));
+
+        y += 2;
+        x += 1;
+    }
+
+    // *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = sumf0 + sumf1;
 #else
     float sumf = 0.0f;

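The shift0/shift3 constants are powers of three: multiplying a q byte by 3^(4-m) modulo 256 moves ternary digit m to the top, and the vhaddq_u8/vshrq_n_u8 pair then evaluates (t * 3) >> 8 without leaving 8-bit lanes, the same step the scalar fallback writes as (xj * 3) >> 8 in 16-bit arithmetic. A hedged scalar model of the digit extraction, assuming each q byte packs five base-3 digits scaled into 0..255:

static int8_t q1_3_digit_sketch(uint8_t q, int m) {  // m = 0..4: which of the 5 trits
    static const uint8_t mul[5] = {81, 27, 9, 3, 1}; // 3^4 .. 3^0, as in k_shift0/k_shift3
    const uint8_t t = (uint8_t)(q * mul[m]);         // wraps mod 256, like vmulq_u8
    return (int8_t)(((t * 3) >> 8) - 1);             // {0,1,2} -> {-1,0,1}, like the vsubq_s8 by one
}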
@@ -11393,34 +11543,36 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
         for (int j = 0; j < 8; ++j) {
             const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]);
             for (int k = 0; k < 4; ++k) {
-                sum += xj[k] * (int16_t) y[2*i].qs[4*j + k];
+                sum += xj[k] * (int16_t) y->qs[4*j + k];
             }
         }

-        sumf += GGML_FP16_TO_FP32(y[2*i].d) * sum;
+        sumf += GGML_FP16_TO_FP32(y->d) * sum;
+        y += 1;
         sum = 0;

         for (int j = 0; j < 4; ++j) {
             const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]);
             for (int k = 0; k < 4; ++k) {
-                sum += xj[k] * (int16_t) y[2*i + 1].qs[4*j + k];
+                sum += xj[k] * (int16_t) y->qs[4*j + k];
             }
         }

         for (size_t j = 0; j < 12; ++j) {
             uint16_t xj = x[i].q[j];
             xj = (xj * 3) >> 8;
-            sum += ((int16_t) xj - 1) * (int16_t) y[2*i + 1].qs[16 + j];
+            sum += ((int16_t) xj - 1) * (int16_t) y->qs[16 + j];
         }

         {
             const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]);
             for (int k = 0; k < 4; ++k) {
-                sum += (int16_t) xj[k] * (int16_t) y[2*i + 1].qs[28 + k];
+                sum += (int16_t) xj[k] * (int16_t) y->qs[28 + k];
             }
         }

-        sumf += GGML_FP16_TO_FP32(y[2*i + 1].d) * sum;
+        sumf += GGML_FP16_TO_FP32(y->d) * sum;
+        y += 1;
     }

     *s = sumf;

ggml.c

Lines changed: 12 additions & 0 deletions
@@ -836,6 +836,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
     },
+    [GGML_TYPE_Q2_2] = {
+        .type_name = "q2_2",
+        .blck_size = QK2_2,
+        .type_size = sizeof(block_q2_2),
+        .is_quantized = true,
+        .to_float = (ggml_to_float_t) dequantize_row_q2_2,
+        .from_float = quantize_row_q2_2,
+        .from_float_reference = (ggml_from_float_t) quantize_row_q2_2_reference,
+        .vec_dot = ggml_vec_dot_q2_2_q8_0,
+        .vec_dot_type = GGML_TYPE_Q8_0,
+        .nrows = 1,
+    },
     [GGML_TYPE_Q1_3] = {
         .type_name = "q1_3",
         .blck_size = QK1_3,
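With this entry in place, generic ggml code can route q2_2 dot products through the traits table: activations are quantized to the declared vec_dot_type (q8_0 here) and then handed to vec_dot. A hedged sketch of that glue, not code from this commit (row_dot_q2_2 is a made-up helper; n is assumed to be a multiple of QK8_0):

static void row_dot_q2_2(int n, float * s, const void * x_q2_2, const float * y_f32) {
    const ggml_type_traits_t t = type_traits[GGML_TYPE_Q2_2];
    block_q8_0 y_q8[n / QK8_0];                // quantize activations to the vec_dot_type
    quantize_row_q8_0(y_f32, y_q8, n);
    t.vec_dot(n, s, 0, x_q2_2, 0, y_q8, 0, 1); // calls ggml_vec_dot_q2_2_q8_0
}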
