Skip to content

Commit 9bfecf4

Browse files
committed
ggml : RISC-V vector gemv for q4_0_8x8
1 parent 66c2c93 commit 9bfecf4

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

ggml/src/ggml-aarch64.c

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,81 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
991991
}
992992
}
993993
return;
994+
#elif defined(__riscv_v_intrinsic)
995+
if (__riscv_vlenb() >= QK4_0) {
996+
const size_t vl = QK4_0;
997+
const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl);
998+
const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1);
999+
const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl);
1000+
const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi);
1001+
const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8)));
1002+
const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8)));
1003+
const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8)));
1004+
const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8)));
1005+
const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8)));
1006+
const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8)));
1007+
const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8)));
1008+
const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8)));
1009+
1010+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1011+
for (int x = 0; x < nc / ncols_interleaved; x++) {
1012+
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
1013+
1014+
vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
1015+
for (int l = 0; l < nb; l++) {
1016+
const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl);
1017+
const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec);
1018+
const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl));
1019+
const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4);
1020+
const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4);
1021+
1022+
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
1023+
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
1024+
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
1025+
1026+
const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
1027+
const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
1028+
const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
1029+
1030+
const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
1031+
const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4);
1032+
const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4);
1033+
const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4);
1034+
const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4);
1035+
const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4);
1036+
const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4);
1037+
const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4);
1038+
const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4);
1039+
const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4);
1040+
const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4);
1041+
const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4);
1042+
const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4);
1043+
const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4);
1044+
const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4);
1045+
const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4);
1046+
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4);
1047+
1048+
// vector version needs Zvfhmin extension
1049+
const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
1050+
const float b_scales[8] = {
1051+
GGML_FP16_TO_FP32(b_ptr[l].d[0]),
1052+
GGML_FP16_TO_FP32(b_ptr[l].d[1]),
1053+
GGML_FP16_TO_FP32(b_ptr[l].d[2]),
1054+
GGML_FP16_TO_FP32(b_ptr[l].d[3]),
1055+
GGML_FP16_TO_FP32(b_ptr[l].d[4]),
1056+
GGML_FP16_TO_FP32(b_ptr[l].d[5]),
1057+
GGML_FP16_TO_FP32(b_ptr[l].d[6]),
1058+
GGML_FP16_TO_FP32(b_ptr[l].d[7])
1059+
};
1060+
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
1061+
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
1062+
const vfloat32m1_t tmp2 = __riscv_vfmul_vv_f32m1(tmp1, b_scales_vec, vl / 4);
1063+
sumf = __riscv_vfadd_vv_f32m1(sumf, tmp2, vl / 4);
1064+
}
1065+
__riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
1066+
}
1067+
return;
1068+
}
9941069
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
9951070
{
9961071
float sumf[8];

0 commit comments

Comments
 (0)