
Commit c67cc98

ggml: aarch64: implement SVE kernels for q4_K_q8_K vector dot (#11227)

* Add SVE support for q4_K_q8_K

* Update ggml/src/ggml-cpu/ggml-cpu-quants.c: change to use K_SCALE_SIZE

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent adc5dd9 commit c67cc98
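
For orientation: the new kernel computes the q4_K x q8_K super-block dot product that ggml's scalar reference path defines. Each q4_K super-block packs QK_K = 256 4-bit quants into eight sub-blocks of 32, with 6-bit per-sub-block scales and mins applied on top of the fp16 super-block factors d and dmin; q8_K additionally carries precomputed bsums of each group of 16 int8 values, which the kernel uses for the mins correction. Below is a minimal scalar sketch of the same computation, assuming the packed 6-bit scales/mins have already been unpacked into sc[8] and mn[8] (the helper name and pre-unpacked arguments are illustrative, not from this commit):

    #include <stdint.h>

    // Scalar sketch of the q4_K x q8_K dot for one super-block (QK_K = 256,
    // eight sub-blocks of 32). Inputs are assumed pre-unpacked, so this is
    // illustrative rather than the layout-exact ggml reference:
    //   d, dmin : super-block scale/min, already multiplied by y[i].d
    //   q4      : 128 bytes, two 4-bit quants per byte
    //   q8      : 256 signed 8-bit quants
    //   sc, mn  : the eight 6-bit sub-block scales and mins
    //   bsums   : sums of each group of 16 q8 values (as in block_q8_K)
    static float dot_q4K_q8K_superblock(float d, float dmin,
                                        const uint8_t q4[128], const int8_t q8[256],
                                        const uint8_t sc[8], const uint8_t mn[8],
                                        const int16_t bsums[16]) {
        int32_t sumi = 0;
        for (int j = 0; j < 256/64; ++j) {
            int32_t s1 = 0, s2 = 0;
            for (int k = 0; k < 32; ++k) {
                s1 += (q4[32*j + k] & 0xf) * q8[64*j + k];       // low nibbles: sub-block 2j
                s2 += (q4[32*j + k] >> 4)  * q8[64*j + 32 + k];  // high nibbles: sub-block 2j+1
            }
            sumi += sc[2*j + 0]*s1 + sc[2*j + 1]*s2;
        }
        int32_t summ = 0;  // mins correction from the precomputed q8 sums
        for (int j = 0; j < 8; ++j) {
            summ += mn[j] * (bsums[2*j + 0] + bsums[2*j + 1]);   // covers 32 q8 values
        }
        return d*sumi - dmin*summ;
    }

In the diff below, sumi1/sumi2 accumulate the nibble dot products with svdot_s32, while the mins correction is computed up front with NEON intrinsics on the bsums.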

File tree

1 file changed: +82 -1 lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 82 additions & 1 deletion
@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
 
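
The switch on vector_length picks a variant by the hardware SVE register width in bits: the 128-bit path processes 16 bytes per load with all-true predicates (svptrue_b8()), while the shared 256/512-bit path fixes the active lanes to 32 bytes with svptrue_pat_b8(SV_VL32) (and 8 words with svptrue_pat_b32(SV_VL8)), so the same loads and dot products run on wider registers without touching inactive lanes. A minimal sketch of the vector-length helper, on the assumption that it simply wraps the ACLE svcntb intrinsic; the actual ggml implementation may differ:

    #include <arm_sve.h>

    // Assumed sketch: svcntb() returns the SVE vector length in bytes,
    // so ggml_cpu_get_sve_cnt()*8 in the kernel above is the width in
    // bits (128, 256, or 512 on current hardware).
    static int ggml_cpu_get_sve_cnt(void) {
        return (int) svcntb();
    }

Note that this whole branch compiles only when the compiler defines __ARM_FEATURE_SVE, i.e. when SVE is enabled in the target flags (for GCC/Clang, e.g. -march=armv8.2-a+sve); otherwise the build falls through to the existing __ARM_NEON path.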
