
Commit 3f7fdf2

ggml : Added WIP rvv q4_0_8x8 gemm
1 parent 9bfecf4 commit 3f7fdf2

File tree: 1 file changed (+81, -36)

ggml/src/ggml-aarch64.c
@@ -994,56 +994,57 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
 #elif defined(__riscv_v_intrinsic)
     if (__riscv_vlenb() >= QK4_0) {
         const size_t vl = QK4_0;
-        const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl);
-        const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1);
-        const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl);
-        const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi);
-        const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8)));
-        const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8)));
-        const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8)));
-        const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8)));
-        const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8)));
-        const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8)));
-        const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8)));
-        const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8)));
+        const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF};
+        const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2);
+        const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);

         const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
         for (int x = 0; x < nc / ncols_interleaved; x++) {
             const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);

             vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
             for (int l = 0; l < nb; l++) {
-                const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl);
-                const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec);
-                const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl));
-                const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4);
-                const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4);
+                const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8));
+                const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[8], vl / 8));
+                const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[16], vl / 8));
+                const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[24], vl / 8));
+                const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4);
+                const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4);

                 const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
                 const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
                 const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);

                 const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
                 const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
-                const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
-
-                const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
-                const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4);
-                const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4);
-                const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4);
-                const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4);
-                const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4);
-                const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4);
-                const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4);
-                const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4);
-                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4);
+                const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
+                const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0);
+                const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1);
+                vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2);
+
+                const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                const vint32m1_t iacc6s = __riscv_vslide1up_vx_i32m1(iacc7, __riscv_vmv_x_s_i32m1_i32(iacc6), vl / 4);
+                const vint32m1_t iacc5s = __riscv_vslide1up_vx_i32m1(iacc6s, __riscv_vmv_x_s_i32m1_i32(iacc5), vl / 4);
+                const vint32m1_t iacc4s = __riscv_vslide1up_vx_i32m1(iacc5s, __riscv_vmv_x_s_i32m1_i32(iacc4), vl / 4);
+                const vint32m1_t iacc3s = __riscv_vslide1up_vx_i32m1(iacc4s, __riscv_vmv_x_s_i32m1_i32(iacc3), vl / 4);
+                const vint32m1_t iacc2s = __riscv_vslide1up_vx_i32m1(iacc3s, __riscv_vmv_x_s_i32m1_i32(iacc2), vl / 4);
+                const vint32m1_t iacc1s = __riscv_vslide1up_vx_i32m1(iacc2s, __riscv_vmv_x_s_i32m1_i32(iacc1), vl / 4);
+                const vint32m1_t iacc0s = __riscv_vslide1up_vx_i32m1(iacc1s, __riscv_vmv_x_s_i32m1_i32(iacc0), vl / 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0s, vl / 4);

                 // vector version needs Zvfhmin extension
                 const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
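
For orientation: the rewritten gemv loop above computes, for each q4_0x8 block of 8 interleaved weight columns and the matching q8_0 activation block, eight per-column integer dot products, which are then scaled by the fp16 scales. Below is a plain-C sketch of that per-block step. It is illustrative only and not code from this commit: the helper name is hypothetical, and the indexing assumes the interleaved layout implied by the intrinsics above (8 columns interleaved in 8-byte groups, low nibbles covering activation positions 0..15 and high nibbles positions 16..31, nibbles already offset-corrected so they sign-extend directly); block_q4_0x8, block_q8_0 and GGML_FP16_TO_FP32 are the definitions this file already uses.

    // Hypothetical scalar reference for one (q4_0x8, q8_0) block pair.
    // sumf[j] accumulates the result for interleaved column j (caller zero-initializes).
    static void q4_0x8_q8_0_dot_ref(const block_q4_0x8 * b, const block_q8_0 * a, float sumf[8]) {
        for (int j = 0; j < 8; j++) {                       // 8 interleaved weight columns
            int sumi = 0;
            for (int k = 0; k < 2; k++) {                   // two 8-byte groups per nibble half
                for (int i = 0; i < 8; i++) {
                    const int8_t q = b->qs[k * 64 + j * 8 + i];
                    const int lo = (int8_t)(q << 4) >> 4;   // low nibble  -> activations 0..15
                    const int hi = q >> 4;                  // high nibble -> activations 16..31
                    sumi += lo * a->qs[k * 8 + i] + hi * a->qs[k * 8 + i + 16];
                }
            }
            sumf[j] += sumi * GGML_FP16_TO_FP32(b->d[j]) * GGML_FP16_TO_FP32(a->d);
        }
    }

The new reduction scheme keeps a single mask over the top eight int16 lanes and repeatedly slides sumi up by 8 so that each column's partial sums pass under that mask; the eight scalar results are then stitched back into one vector with vslide1up before the conversion to float.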
@@ -3246,6 +3247,50 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
                 }
             }
         }
+        return;
+    }
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+        const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF};
+        const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2);
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+                // for (int m = 0; m < 4; m++) {
+                //     for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                // }
+                vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                for (int l = 0; l < nb; l++) {
+                    const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                    const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                    const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+
+                    {
+                        const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8));
+                        const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[32], vl / 8));
+                        const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[64], vl / 8));
+                        const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[96], vl / 8));
+                        const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4);
+                        const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4);
+                    }
+                }
+                // for (int m = 0; m < 4; m++) {
+                //     for (int j = 0; j < ncols_interleaved; j++)
+                //         s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                // }
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
+            }
+        }
+
         return;
     }
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
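
The gemm hunk above is still a skeleton, as the commit title says: the per-block integer accumulation and the scaling into sumf0..sumf3 are not wired up yet, so the lhs vectors are built and then discarded and the stores write zeros. For orientation, a scalar sketch of the 4x8 output tile that the commented-out loops describe is below. It is illustrative only and not part of this commit: the helper name is hypothetical, and the block_q8_0x4 indexing assumes 4 activation rows interleaved in 8-byte groups, which matches the qs[0]/qs[32]/qs[64]/qs[96] loads above for row 0 but is an assumption for the other rows.

    // Hypothetical scalar reference for one 4x8 tile of the gemm path
    // (the computation the commented-out loops above outline).
    static void q4_0x8_q8_0x4_tile_ref(int nb, const block_q4_0x8 * b_ptr, const block_q8_0x4 * a_ptr, float sumf[4][8]) {
        for (int m = 0; m < 4; m++)
            for (int j = 0; j < 8; j++)
                sumf[m][j] = 0.0f;
        for (int l = 0; l < nb; l++) {
            for (int m = 0; m < 4; m++) {               // 4 interleaved activation rows
                for (int j = 0; j < 8; j++) {           // 8 interleaved weight columns
                    int sumi = 0;
                    for (int k = 0; k < 2; k++) {
                        for (int i = 0; i < 8; i++) {
                            const int8_t q = b_ptr[l].qs[k * 64 + j * 8 + i];
                            const int lo = (int8_t)(q << 4) >> 4;
                            const int hi = q >> 4;
                            // assumed layout: row m, activation position p -> qs[(p / 8) * 32 + m * 8 + p % 8]
                            sumi += lo * a_ptr[l].qs[ k      * 32 + m * 8 + i]
                                  + hi * a_ptr[l].qs[(k + 2) * 32 + m * 8 + i];
                        }
                    }
                    sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
                }
            }
        }
    }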
