Skip to content

Commit d9a1452

Browse files
authored
ggml : add SVE support for q6_K_q8_K (ggml-org#12361)
1 parent fd123cf commit d9a1452

File tree

1 file changed

+150
-1
lines changed

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 150 additions & 1 deletion

Original file line number | Diff line number | Diff line change
@@ -8158,7 +8158,156 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
81588158

81598159
const int nb = n / QK_K;
81608160

8161-
#ifdef __ARM_NEON
8161+
#ifdef __ARM_FEATURE_SVE
8162+
const int vector_length = ggml_cpu_get_sve_cnt()*8;
8163+
float sum = 0;
8164+
svuint8_t m4b = svdup_n_u8(0xf);
8165+
svint32_t vzero = svdup_n_s32(0);
8166+
svuint8_t mone = svdup_n_u8(0x30);
8167+
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
8168+
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
8169+
8170+
for (int i = 0; i < nb; ++i) {
8171+
const float d_all = GGML_FP16_TO_FP32(x[i].d);
8172+
8173+
const uint8_t * GGML_RESTRICT q6 = x[i].ql;
8174+
const uint8_t * GGML_RESTRICT qh = x[i].qh;
8175+
const int8_t * GGML_RESTRICT q8 = y[i].qs;
8176+
8177+
const int8_t * GGML_RESTRICT scale = x[i].scales;
8178+
8179+
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
8180+
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
8181+
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
8182+
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
8183+
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
8184+
const svint64_t prod = svdup_n_s64(0);
8185+
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
8186+
svdot_s64(prod, q8sums_2, q6scales_2)));
8187+
int32_t isum = 0;
8188+
8189+
switch (vector_length) {
8190+
case 128:
8191+
{
8192+
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
8193+
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
8194+
svint32_t isum_tmp = svdup_n_s32(0);
8195+
for (int j = 0; j < QK_K/128; ++j) {
8196+
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
8197+
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
8198+
qh += 32;
8199+
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
8200+
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
8201+
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
8202+
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
8203+
q6 += 64;
8204+
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
8205+
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
8206+
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
8207+
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
8208+
q8 += 64;
8209+
8210+
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
8211+
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
8212+
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
8213+
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
8214+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
8215+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
8216+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
8217+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
8218+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
8219+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
8220+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
8221+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
8222+
8223+
scale += 4;
8224+
q8bytes_1 = svld1_s8(pg8_16, q8);
8225+
q8bytes_2 = svld1_s8(pg8_16, q8+16);
8226+
q8bytes_3 = svld1_s8(pg8_16, q8+32);
8227+
q8bytes_4 = svld1_s8(pg8_16, q8+48);
8228+
q8 += 64;
8229+
8230+
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
8231+
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
8232+
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
8233+
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
8234+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
8235+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
8236+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
8237+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
8238+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
8239+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
8240+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
8241+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
8242+
scale += 4;
8243+
}
8244+
isum += svaddv_s32(pg32_4, isum_tmp);
8245+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
8246+
}
8247+
break;
8248+
case 256:
8249+
case 512:
8250+
{
8251+
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
8252+
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
8253+
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
8254+
svint32_t isum_tmp = svdup_n_s32(0);
8255+
for (int j = 0; j < QK_K/128; j++) {
8256+
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
8257+
qh += 32;
8258+
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
8259+
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
8260+
q6 += 64;
8261+
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
8262+
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
8263+
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
8264+
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
8265+
q8 += 128;
8266+
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
8267+
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
8268+
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
8269+
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
8270+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
8271+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
8272+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
8273+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
8274+
8275+
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
8276+
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
8277+
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
8278+
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
8279+
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
8280+
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
8281+
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
8282+
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
8283+
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
8284+
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
8285+
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
8286+
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
8287+
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
8288+
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
8289+
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
8290+
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
8291+
8292+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
8293+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
8294+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
8295+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
8296+
scale += 8;
8297+
}
8298+
isum += svaddv_s32(pg32_8, isum_tmp);
8299+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
8300+
}
8301+
break;
8302+
default:
8303+
assert(false && "Unsupported vector length");
8304+
break;
8305+
}
8306+
}
8307+
8308+
*s = sum;
8309+
8310+
#elif __ARM_NEON
81628311
float sum = 0;
81638312

81648313
const uint8x16_t m4b = vdupq_n_u8(0xF);

0 commit comments

Comments (0)