
Fix LoongArch compile error with 128-bit SIMD #11701


Merged: 1 commit, Feb 6, 2025
169 changes: 91 additions & 78 deletions ggml/src/ggml-cpu/ggml-cpu-quants.c
@@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif

#if defined(__loongarch_sx)

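// Saturate each int32 lane to the int16 range and pack: a fills the low half,
// b the high half (LSX counterpart of SSE's _mm_packs_epi32).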
static __m128i lsx_packs_w(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_w(a, 15);
tmp1 = __lsx_vsat_w(b, 15);
return __lsx_vpickev_h(tmp1, tmp);
}

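// Saturate each int16 lane to the int8 range and pack a (low) with b (high),
// mirroring _mm_packs_epi16.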
static __m128i lsx_packs_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_h(a, 7);
tmp1 = __lsx_vsat_h(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}

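// Pack with unsigned saturation: clamp each int16 lane to [0, 255] before
// packing, mirroring _mm_packus_epi16.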
static __m128i lsx_packus_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_hu(a, 7);
tmp1 = __lsx_vsat_hu(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}

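// Widening multiply of the even and odd int8 lanes, then a saturating add of
// the two products per pair; used here like _mm_maddubs_epi16, though both
// operands are treated as signed bytes.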
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_h_b(a, b);
tmp2 = __lsx_vmulwod_h_b(a, b);
return __lsx_vsadd_h(tmp1, tmp2);
}

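// Widening multiply of the even and odd int16 lanes, then add the two int32
// products per pair (LSX counterpart of _mm_madd_epi16).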
static __m128i lsx_madd_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_w_h(a, b);
tmp2 = __lsx_vmulwod_w_h(a, b);
return __lsx_vadd_w(tmp1, tmp2);
}

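// Build a vector from four int32 values with a in the highest lane, matching
// the argument order of _mm_set_epi32.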
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
v4i32 __ret = {d, c, b, a};
return (__m128i)__ret;
}

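// Byte shuffle with _mm_shuffle_epi8 semantics: the low 4 bits of each lane of
// b index into a, and lanes of b with the sign bit set yield zero.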
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
__m128i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lsx_vreplgr2vr_b(f);
zero = __lsx_vldi(0);
tmp0 = __lsx_vand_v(b, mask_f); // keep the low 4 index bits and the sign bit
tmp0 = __lsx_vori_b(tmp0, 0x10); // set bit 4 so in-range indices select from a in vshuf
mask = __lsx_vsle_b(zero, tmp0); // all-ones for lanes whose sign bit is clear
tmp2 = __lsx_vand_v(tmp0, mask); // force sign-bit lanes to index 0, which reads the zero vector
return __lsx_vshuf_b(a, zero, tmp2);
}

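// Horizontal add of adjacent int16 pairs: pair sums of a land in the low half,
// pair sums of b in the high half (as in _mm_hadd_epi16).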
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_h(b, a);
__m128i tmp2 = __lsx_vpickod_h(b, a);
return __lsx_vadd_h(tmp1, tmp2);
}

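// Horizontal add of adjacent int32 pairs, as in _mm_hadd_epi32.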
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_w(b, a);
__m128i tmp2 = __lsx_vpickod_w(b, a);
return __lsx_vadd_w(tmp1, tmp2);
}

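// Horizontal add of adjacent float pairs, as in _mm_hadd_ps.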
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);

return __lsx_vfadd_s(tmp1, tmp2);
}

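// Horizontally reduce four float vectors to a single scalar: the sum of all 16 lanes.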
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
__m128 res_0 = lsx_hadd_s(a, b);
__m128 res_1 = lsx_hadd_s(c, d);
__m128 res = lsx_hadd_s(res_0, res_1);
res = lsx_hadd_s(res, res);
res = lsx_hadd_s(res, res);

return ((v4f32)res)[0];
}
#endif

#if defined(__loongarch_asx)

#ifdef __clang__
@@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
return (__m256i)__ret;
}

static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
v4i32 __ret = {d, c, b, a};
return (__m128i)__ret;
}

static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
v4i64 __ret = {d, c, b, a};
return (__m256i)__ret;
@@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
return lasx_set_q(x, y);
}

static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
__m128i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lsx_vreplgr2vr_b(f);
zero = __lsx_vldi(0);
tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
return __lsx_vshuf_b(a, zero, tmp2);
}

static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
__m256i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
@@ -482,25 +549,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
return ret;
}

static __m128i lsx_hadd_h(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_h(b, a);
__m128i tmp2 = __lsx_vpickod_h(b, a);
return __lsx_vadd_h(tmp1, tmp2);
}

static __m128i lsx_hadd_w(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_w(b, a);
__m128i tmp2 = __lsx_vpickod_w(b, a);
return __lsx_vadd_w(tmp1, tmp2);
}

static __m128 lsx_hadd_s(__m128 a, __m128 b) {
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);

return __lsx_vfadd_s(tmp1, tmp2);
}

static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -529,42 +577,6 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
return __lasx_xvpickev_b(tmp1, tmp);
}

static __m128i lsx_packs_w(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_w(a, 15);
tmp1 = __lsx_vsat_w(b, 15);
return __lsx_vpickev_h(tmp1, tmp);
}

static __m128i lsx_packs_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_h(a, 7);
tmp1 = __lsx_vsat_h(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}

static __m128i lsx_packus_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_hu(a, 7);
tmp1 = __lsx_vsat_hu(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}


static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_h_b(a, b);
tmp2 = __lsx_vmulwod_h_b(a, b);
return __lsx_vsadd_h(tmp1, tmp2);
}

static __m128i lsx_madd_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_w_h(a, b);
tmp2 = __lsx_vmulwod_w_h(a, b);
return __lsx_vadd_w(tmp1, tmp2);
}

// multiply int8_t, add results pairwise twice
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// Get absolute values of x vectors
@@ -2232,21 +2244,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}

sumf = hsum_float_8(acc);

#elif defined(__loongarch_sx)
// set constants
const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
const __m128i off = __lsx_vreplgr2vr_b(8);

// Initialize accumulator with zeros
-__m128 acc_0 = __lsx_vldi(0);
-__m128 acc_1 = __lsx_vldi(0);
-__m128 acc_2 = __lsx_vldi(0);
-__m128 acc_3 = __lsx_vldi(0);
+__m128 acc_0 = (__m128)__lsx_vldi(0);
+__m128 acc_1 = (__m128)__lsx_vldi(0);
+__m128 acc_2 = (__m128)__lsx_vldi(0);
+__m128 acc_3 = (__m128)__lsx_vldi(0);

for (; ib + 1 < nb; ib += 2) {

// Compute combined scale for the block 0 and 1
-const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );

const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);

@@ -2264,7 +2277,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
//_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);

// Compute combined scale for the block 2 and 3
-const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );

const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
