Skip to content

Commit 78c78e2

Browse files
committed
ggml : Fix GCC rvv load alignment issue
1 parent c039415 commit 78c78e2

File tree

1 file changed

+45
-23
lines changed

1 file changed

+45
-23
lines changed

ggml/src/ggml-aarch64.c

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,10 +1001,15 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
10011001

10021002
vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
10031003
for (int l = 0; l < nb; l++) {
1004-
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[0], 0, vl / 4));
1005-
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4));
1006-
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4));
1007-
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4));
1004+
const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
1005+
const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
1006+
const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
1007+
const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
1008+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
1009+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
1010+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
1011+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
1012+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
10081013

10091014
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
10101015
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
@@ -3275,12 +3280,17 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
32753280
};
32763281
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
32773282

3283+
const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
3284+
const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
3285+
const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
3286+
const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
3287+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
32783288
vint16m4_t sumi_l0;
32793289
{
3280-
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[0], 0, vl / 4));
3281-
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[32], 0, vl / 4));
3282-
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[64], 0, vl / 4));
3283-
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[96], 0, vl / 4));
3290+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
3291+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
3292+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
3293+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
32843294
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
32853295
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
32863296
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
@@ -3307,14 +3317,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
33073317
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
33083318
sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
33093319
}
3310-
// __asm__ __volatile__("" ::: "memory");
33113320

3321+
const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
3322+
const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
3323+
const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
3324+
const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
3325+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
33123326
vint16m4_t sumi_l1;
33133327
{
3314-
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4));
3315-
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[40], 0, vl / 4));
3316-
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[72], 0, vl / 4));
3317-
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[104], 0, vl / 4));
3328+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
3329+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
3330+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
3331+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
33183332
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
33193333
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
33203334
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
@@ -3341,14 +3355,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
33413355
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
33423356
sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
33433357
}
3344-
// __asm__ __volatile__("" ::: "memory");
33453358

3359+
const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
3360+
const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
3361+
const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
3362+
const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
3363+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
33463364
vint16m4_t sumi_l2;
33473365
{
3348-
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4));
3349-
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[48], 0, vl / 4));
3350-
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[80], 0, vl / 4));
3351-
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[112], 0, vl / 4));
3366+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
3367+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
3368+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
3369+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
33523370
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
33533371
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
33543372
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
@@ -3375,14 +3393,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
33753393
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
33763394
sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
33773395
}
3378-
// __asm__ __volatile__("" ::: "memory");
33793396

3397+
const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
3398+
const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
3399+
const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
3400+
const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
3401+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
33803402
vint16m4_t sumi_l3;
33813403
{
3382-
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4));
3383-
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[56], 0, vl / 4));
3384-
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[88], 0, vl / 4));
3385-
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[120], 0, vl / 4));
3404+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
3405+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
3406+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
3407+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
33863408
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
33873409
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
33883410
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);

0 commit comments

Comments
 (0)