@@ -991,6 +991,81 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
         }
     }
     return;
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+        const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl);
+        const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1);
+        const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl);
+        const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi);
+        const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8)));
+        const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8)));
+        const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8)));
+        const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8)));
+        const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8)));
+        const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8)));
+        const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8)));
+        const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8)));
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+            for (int l = 0; l < nb; l++) {
+                const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl);
+                const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec);
+                const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl));
+                const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4);
+                const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4);
+
+                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+
+                const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
+                const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
+                const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
+
+                const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
+                const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4);
+                const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4);
+                const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4);
+                const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4);
+                const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4);
+                const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4);
+                const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4);
+                const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4);
+                const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4);
+
+                // vector version needs Zvfhmin extension
+                const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
+                const float b_scales[8] = {
+                    GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                };
+                const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+                const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
+                const vfloat32m1_t tmp2 = __riscv_vfmul_vv_f32m1(tmp1, b_scales_vec, vl / 4);
+                sumf = __riscv_vfadd_vv_f32m1(sumf, tmp2, vl / 4);
+            }
+            __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
+        }
+        return;
+    }
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
     {
         float sumf[8];
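For readers following the lane bookkeeping in the hunk above, here is a small standalone sketch. It is illustrative only and not part of the patch; it assumes VLEN = 256, so `vl = QK4_0 = 32` and the widened product vector `sumi` holds `vl * 4 = 128` int16 lanes. Under that assumption, `lhs_idx_m4` is built to gather activation byte `(lane & 7)` in the first 64 lanes and `(lane & 7) + 8` in the second 64, while `mask<j>` selects lanes `8*j .. 8*j+7` of each 64-lane half, i.e. exactly the products that belong to interleaved column `j` of the `block_q4_0x8`.

```c
#include <stdio.h>

#define QK4_0 32

/* Plain-C illustration of the RVV kernel's lane layout (no RVV needed).
 * Assumes VLEN = 256: 128 product lanes per low/high nibble pass.       */
int main(void) {
    const int vl = QK4_0;                            // 32
    for (int lane = 0; lane < vl * 4; lane++) {      // 128 lanes of sumi
        const int col  = (lane / 8) & 7;             // column picked out by mask<col>
        const int half = lane < vl * 2 ? 0 : 1;      // 0: bytes 0-7, 1: bytes 8-15 of each column
        const int idx  = (lane & 7) + 8 * half;      // index produced by lhs_idx_m4
        // For the low-nibble pass, this lane multiplies a_ptr[l].qs[idx] with
        // element idx of column col; the high-nibble pass covers elements idx + 16.
        if (lane % 16 == 0)                          // print a sample of lanes
            printf("lane %3d -> column %d, element %2d (a.qs[%2d])\n", lane, col, idx, idx);
    }
    return 0;
}
```

Each masked `vwredsum` therefore accumulates the 32 products of one column (16 from the low-nibble pass and 16 from the high-nibble pass, already added together in `sumi`), and the `vslideup` chain packs the eight column sums into a single `vint32m1_t` so they can be scaled by `a_scale` and the eight `b_scales` with one `vfmul`/`vfadd` per block.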