
ggml: aarch64: implement SVE kernels for q6_K_q8_K vector dot #12361


Merged 1 commit on Mar 18, 2025
ggml/src/ggml-cpu/ggml-cpu-quants.c (151 changes: 150 additions & 1 deletion)
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

    const int nb = n / QK_K;

-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
    // SVE vector length in bits (ggml_cpu_get_sve_cnt() reports bytes)
    const int vector_length = ggml_cpu_get_sve_cnt()*8;
    float sum = 0;
    svuint8_t m4b = svdup_n_u8(0xf);   // mask for the low 4 bits of each value
    svint32_t vzero = svdup_n_s32(0);
    svuint8_t mone = svdup_n_u8(0x30); // mask for the two high bits (bits 4..5)
    svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
    svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;

    for (int i = 0; i < nb; ++i) {
        const float d_all = GGML_FP16_TO_FP32(x[i].d);

        const uint8_t * GGML_RESTRICT q6 = x[i].ql;
        const uint8_t * GGML_RESTRICT qh = x[i].qh;
        const int8_t * GGML_RESTRICT q8 = y[i].qs;

        const int8_t * GGML_RESTRICT scale = x[i].scales;

        // dot the q8 per-sub-block sums with the q6 scales; used at the end to
        // fold in the constant +32 storage bias of the 6-bit values
        const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
        const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
        const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
        const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
        const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
        const svint64_t prod = svdup_n_s64(0);
        int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
                                                                                 svdot_s64(prod, q8sums_2, q6scales_2)));
        int32_t isum = 0;

        switch (vector_length) {
            case 128:
                {
                    const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
                    const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
                    svint32_t isum_tmp = svdup_n_s32(0);
                    for (int j = 0; j < QK_K/128; ++j) {
                        svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
                        svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
                        qh += 32;
                        svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
                        svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
                        svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
                        svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
                        q6 += 64;
                        svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
                        svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
                        svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
                        svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
                        q8 += 64;

                        // first 64 values of this 128-value chunk: low nibble of ql,
                        // with qh bit pairs 0..1 and 2..3 shifted up into bits 4..5
                        q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
                        q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
                        q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
                        q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
                        // multiply-accumulate each 16-value sub-block with its int8 scale
                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);

                        scale += 4;
                        q8bytes_1 = svld1_s8(pg8_16, q8);
                        q8bytes_2 = svld1_s8(pg8_16, q8+16);
                        q8bytes_3 = svld1_s8(pg8_16, q8+32);
                        q8bytes_4 = svld1_s8(pg8_16, q8+48);
                        q8 += 64;

                        // second 64 values: high nibble of ql, with qh bit pairs 4..5 and 6..7
                        q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
                        q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
                        q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
                        q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
                        isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
                        scale += 4;
                    }
                    isum += svaddv_s32(pg32_4, isum_tmp);
                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
                }
                break;
            case 256:
            case 512:
                {
                    const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
                    const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
                    const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
                    svint32_t isum_tmp = svdup_n_s32(0);
                    for (int j = 0; j < QK_K/128; j++) {
                        svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
                        qh += 32;
                        svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
                        svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
                        q6 += 64;
                        svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
                        svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
                        svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
                        svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
                        q8 += 128;
                        // one 32-byte qh vector supplies the high bit pairs for all four
                        // 32-byte groups: shifts of <<4, <<2, none and >>2 move each pair
                        // into bits 4..5 before masking with 0x30
                        q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
                        q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
                        q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
                        q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
                        q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
                        q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
                        q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
                        q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));

                        // broadcast each int8 scale to four adjacent int32 lanes: every
                        // svdot lane covers 4 quants, so 4 lanes span one 16-value sub-block
                        svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
                        scale_lane_1_tmp = svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
                        svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
                        scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
                        svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
                        scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
                        svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
                        scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
                        svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
                        svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
                        svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
                        svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));

                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
                        isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
Comment on lines +8292 to +8295

Member commented:
Maybe something to try is to have 4 separate accumulators here. I don't have a machine that supports SVE to give this a try.

@fj-y-saito (Contributor, Author) commented on Mar 18, 2025:
When I implemented this fix and measured the elapsed time of ggml_vec_dot_q6_K_q8_K with perf, I found a performance degradation of about 5%, so I think it is better to leave the code as it is.

My reasoning is as follows. Providing separate accumulators:

  • reduced the dependency chain on the critical path inside the loop from seven instructions to one;
  • added three add instructions outside the for loop to sum up the separated accumulators (these form a dependency chain of two).

In short, the in-loop instruction dependencies drop from 7 to 1, while the instruction count grows by 3. If the loop trip count were large, the proposed modification would be expected to improve performance. In this case, however, the loop runs only twice (QK_K/128 == 2), so the degradation from the extra instructions dominates.
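
For readers following along without SVE hardware, here is a minimal sketch of the four-accumulator variant being discussed. The acc_1..acc_4 names are illustrative, not from the PR, and this is not the code that was merged; it reuses q6bytes_*, q8bytes_*, scale_lane_*, pg32_8 and vzero exactly as defined in the 256/512-bit path above.

    // Hypothetical variant from the discussion: one accumulator per
    // dot-product stream, combined after the loop.
    svint32_t acc_1 = svdup_n_s32(0);
    svint32_t acc_2 = svdup_n_s32(0);
    svint32_t acc_3 = svdup_n_s32(0);
    svint32_t acc_4 = svdup_n_s32(0);
    for (int j = 0; j < QK_K/128; j++) {
        // ... same unpacking of q6/qh/q8 and scale broadcast as above ...
        acc_1 = svmla_s32_x(pg32_8, acc_1, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
        acc_2 = svmla_s32_x(pg32_8, acc_2, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
        acc_3 = svmla_s32_x(pg32_8, acc_3, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
        acc_4 = svmla_s32_x(pg32_8, acc_4, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
        scale += 8;
    }
    // the three extra vector adds (a dependency chain of two) that the
    // measurement above blames for the ~5% slowdown at a trip count of 2
    isum += svaddv_s32(pg32_8, svadd_s32_x(pg32_8, svadd_s32_x(pg32_8, acc_1, acc_2),
                                                   svadd_s32_x(pg32_8, acc_3, acc_4)));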

                        scale += 8;
                    }
                    isum += svaddv_s32(pg32_8, isum_tmp);
                    sum += d_all * y[i].d * (isum - 32 * isum_mins);
                }
                break;
            default:
                assert(false && "Unsupported vector length");
                break;
        }
    }

    *s = sum;

#elif __ARM_NEON
    float sum = 0;

    const uint8x16_t m4b = vdupq_n_u8(0xF);
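
For context, the scalar logic that both the SVE and NEON kernels vectorize follows the q6_K block layout: 256 values per block, with low nibbles in ql[128], high bit pairs in qh[64], one int8 scale per 16 values, and an fp16 super-scale d. The following is a rough sketch in the spirit of the generic fallback path in the same file, not a verbatim copy of it; consult ggml-cpu-quants.c for the authoritative reference implementation.

    // Sketch: scalar q6_K x q8_K dot product over nb blocks.
    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * ql = x[i].ql;
        const uint8_t * qh = x[i].qh;
        const int8_t  * q8 = y[i].qs;
        int8_t q6[QK_K];
        // reassemble each 6-bit value from a low nibble and a 2-bit pair;
        // stored values are biased by +32 (the SVE path folds this bias in
        // via isum_mins instead of subtracting per element)
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                q6[j + l +  0] = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                q6[j + l + 32] = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                q6[j + l + 64] = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                q6[j + l + 96] = (int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            }
            ql += 64; qh += 32;
        }
        // 16 sub-blocks of 16 values, each weighted by its own int8 scale
        int32_t isum = 0;
        for (int j = 0; j < QK_K/16; ++j) {
            int32_t s16 = 0;
            for (int l = 0; l < 16; ++l) {
                s16 += q6[16*j + l] * q8[16*j + l];
            }
            isum += s16 * x[i].scales[j];
        }
        sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * isum;
    }
    *s = sumf;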