@@ -994,56 +994,57 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
#elif defined(__riscv_v_intrinsic)
    if (__riscv_vlenb() >= QK4_0) {
        const size_t vl = QK4_0;
-        const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl);
-        const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1);
-        const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl);
-        const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi);
-        const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8)));
-        const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8)));
-        const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8)));
-        const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8)));
-        const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8)));
-        const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8)));
-        const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8)));
-        const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8)));
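+        // mask_bytes loads a single vbool4_t that flags the last 8 of the 64 int16 lanes;
+        // the inner loop reuses this one mask for every masked reduction instead of one mask per column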
+        const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF};
+        const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2);
+        const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);

        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);

            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
            for (int l = 0; l < nb; l++) {
-                const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl);
-                const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec);
-                const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl));
-                const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4);
-                const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4);
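+                // broadcast each 8-byte group of activations eight times so it lines up with the 8 interleaved q4_0 columns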
+                const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8));
+                const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[8], vl / 8));
+                const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[16], vl / 8));
+                const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[24], vl / 8));
+                const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4);
+                const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4);

                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);

                const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
                const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
-                const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
-
-                const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
-                const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4);
-                const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4);
-                const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4);
-                const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4);
-                const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4);
-                const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4);
-                const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4);
-                const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4);
-                const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4);
-                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4);
+                const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
+                const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0);
+                const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1);
+                vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2);
+
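+                // reduce the group of 8 lanes under mask7, then slide sumi up by 8 so the next column's partial sums land under the same mask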
+                const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
+                sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
+                const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
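+                // chain the eight scalar sums into one i32m1 vector: element i of iacc0s holds column i's integer dot product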
+                const vint32m1_t iacc6s = __riscv_vslide1up_vx_i32m1(iacc7, __riscv_vmv_x_s_i32m1_i32(iacc6), vl / 4);
+                const vint32m1_t iacc5s = __riscv_vslide1up_vx_i32m1(iacc6s, __riscv_vmv_x_s_i32m1_i32(iacc5), vl / 4);
+                const vint32m1_t iacc4s = __riscv_vslide1up_vx_i32m1(iacc5s, __riscv_vmv_x_s_i32m1_i32(iacc4), vl / 4);
+                const vint32m1_t iacc3s = __riscv_vslide1up_vx_i32m1(iacc4s, __riscv_vmv_x_s_i32m1_i32(iacc3), vl / 4);
+                const vint32m1_t iacc2s = __riscv_vslide1up_vx_i32m1(iacc3s, __riscv_vmv_x_s_i32m1_i32(iacc2), vl / 4);
+                const vint32m1_t iacc1s = __riscv_vslide1up_vx_i32m1(iacc2s, __riscv_vmv_x_s_i32m1_i32(iacc1), vl / 4);
+                const vint32m1_t iacc0s = __riscv_vslide1up_vx_i32m1(iacc1s, __riscv_vmv_x_s_i32m1_i32(iacc0), vl / 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0s, vl / 4);

                // vector version needs Zvfhmin extension
                const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
@@ -3246,6 +3247,50 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
                }
            }
        }
+        return;
+    }
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+        const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF};
+        const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2);
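+        // same mask as in the gemv path above: it flags the last 8 of the 64 int16 lanes for masked reductions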
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+                // for (int m = 0; m < 4; m++) {
+                //     for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                // }
+                vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                for (int l = 0; l < nb; l++) {
+                    const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                    const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                    const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+
+                    {
+                        const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8));
+                        const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[32], vl / 8));
+                        const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[64], vl / 8));
+                        const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[96], vl / 8));
+                        const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4);
+                        const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4);
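+                        // activation bytes broadcast into m4 operands as in the gemv kernel above; they are not consumed yet in this block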
+                    }
+                }
+                // for (int m = 0; m < 4; m++) {
+                //     for (int j = 0; j < ncols_interleaved; j++)
+                //         s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                // }
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
+            }
+        }
+
        return;
    }
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)