
Commit cc6a0f5

ggml : fix iq4_nl dot product with odd number of blocks
1 parent b328344

File tree: 2 files changed, 69 additions, 54 deletions

  ggml/src/ggml-quants.c     (+42, -54)
  tests/test-backend-ops.cpp (+27, -0)

ggml/src/ggml-quants.c (42 additions, 54 deletions)
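The bug: every SIMD path in ggml_vec_dot_iq4_nl_q8_0 consumed blocks two at a time with `for (int ib = 0; ib < nb; ib += 2)`, so when the number of blocks nb was odd the final iteration paired the last block with data one block past the end of the array, and the scalar implementation only existed as a separate `#else` branch. The fix hoists `ib` and `sumf` above the `#if` ladder, makes each vector loop stop while a full pair remains (`ib + 1 < nb`), and lets every path fall through to one shared scalar loop that finishes any leftover block. A minimal sketch of the resulting control flow, not the real kernel: `block_t`, `pair_dot()` and `scalar_dot()` are hypothetical stand-ins for the quantized block types and per-ISA intrinsic bodies in ggml-quants.c.

```c
// Sketch only; block_t, pair_dot() and scalar_dot() are hypothetical
// stand-ins for the real quantized block types and intrinsic loop bodies.
typedef struct { float d; } block_t;

static float pair_dot  (const block_t *x, const block_t *y) { return x[0].d*y[0].d + x[1].d*y[1].d; }
static float scalar_dot(const block_t *x, const block_t *y) { return x->d * y->d; }

static void vec_dot_sketch(int nb, float *s, const block_t *x, const block_t *y) {
    int   ib   = 0;   // block index, hoisted so both loops share it
    float sumf = 0;   // accumulator, shared the same way

    // vector main loop: runs only while a full pair of blocks remains
    for (; ib + 1 < nb; ib += 2) {
        sumf += pair_dot(&x[ib], &y[ib]);
    }
    // scalar tail (previously the #else fallback, now reached by every
    // path): finishes the last block when nb is odd
    for (; ib < nb; ++ib) {
        sumf += scalar_dot(&x[ib], &y[ib]);
    }
    *s = sumf;
}
```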
```diff
@@ -11745,6 +11745,9 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     const int nb = n / QK4_NL;
 
+    int ib = 0;
+    float sumf = 0;
+
 #if defined __ARM_NEON
     const int8x16_t values = vld1q_s8(kvalues_iq4nl);
     const uint8x16_t m4b = vdupq_n_u8(0x0f);
@@ -11753,16 +11756,14 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     int8x16x4_t q8b;
     int32x4_t prod_1, prod_2;
 
-    float sumf = 0;
-
-    for (int ib = 0; ib < nb; ib += 2) {
+    for (; ib + 1 < nb; ib += 2) {
 
-        q4bits.val[0] = vld1q_u8(x[ib+0].qs);
-        q4bits.val[1] = vld1q_u8(x[ib+1].qs);
-        q8b.val[0]    = vld1q_s8(y[ib+0].qs);
-        q8b.val[1]    = vld1q_s8(y[ib+0].qs + 16);
-        q8b.val[2]    = vld1q_s8(y[ib+1].qs);
-        q8b.val[3]    = vld1q_s8(y[ib+1].qs + 16);
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
 
         q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
         q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
@@ -11773,12 +11774,10 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
 
         sumf +=
-            GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
-            GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
+            GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
     }
 
-    *s = sumf;
-
 #elif defined __AVX2__
 
     const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
@@ -11787,11 +11786,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     __m256 accum1 = _mm256_setzero_ps();
     __m256 accum2 = _mm256_setzero_ps();
-    for (int ib = 0; ib < nb; ib += 2) {
-        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
-        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
-        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
-        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
         const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
                                               _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
         const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
@@ -11800,16 +11799,13 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
         const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
         const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
-        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
                 _mm256_cvtepi32_ps(p_1), accum1);
-        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
                 _mm256_cvtepi32_ps(p_2), accum2);
-
-        y += 2;
-        x += 2;
     }
 
-    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
 
 #elif defined __AVX__
     const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
@@ -11818,13 +11814,13 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     __m256 accum1 = _mm256_setzero_ps();
     __m256 accum2 = _mm256_setzero_ps();
-    for (int ib = 0; ib < nb; ib += 2) {
-        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
-        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[1].qs);
-        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
-        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[0].qs + 1);
-        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[1].qs);
-        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[1].qs + 1);
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
 
         const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
         const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
@@ -11838,16 +11834,13 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
         const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
         const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
-        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+        accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
                 _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
-        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+        accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
                 _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
-
-        y += 2;
-        x += 2;
     }
 
-    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
 
 #elif defined(__POWER9_VECTOR__)
     const vector signed char lowMask = vec_splats((signed char)0xF);
@@ -11860,7 +11853,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     const vector signed char values = vec_xl( 0, kvalues_iq4nl);
 
 #pragma GCC unroll 4
-    for (int ib = 0; ib < nb; ++ib) {
+    for (; ib < nb; ++ib) {
         __builtin_prefetch(x[ib].qs, 0, 1);
         __builtin_prefetch(y[ib].qs, 0, 1);
 
@@ -11897,7 +11890,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
     vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
     vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
 
-    *s = vec_extract(vsumf0, 0);
+    sumf = vec_extract(vsumf0, 0);
 
 #elif defined (__loongarch_asx)
 
@@ -11907,11 +11900,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     __m256 accum1 = (__m256)__lasx_xvldi(0);
     __m256 accum2 = (__m256)__lasx_xvldi(0);
-    for (int ib = 0; ib < nb; ib += 2) {
-        const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[0].qs, 0);
-        const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0);
-        const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0);
-        const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0);
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
+        const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
+        const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
+        const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
         const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
                                               lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
         const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
@@ -11920,20 +11913,16 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
         const __m256i p_1 = lasx_madd_h(p16_1, mone);
         const __m256i p_2 = lasx_madd_h(p16_2, mone);
-        accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
+        accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
                 __lasx_xvffint_s_w(p_1), accum1);
-        accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
+        accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
                 __lasx_xvffint_s_w(p_2), accum2);
-
-        y += 2;
-        x += 2;
     }
 
-    *s = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
+    sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
 
-#else
-    float sumf = 0;
-    for (int ib = 0; ib < nb; ++ib) {
+#endif
+    for (; ib < nb; ++ib) {
         const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
         int sumi1 = 0, sumi2 = 0;
         for (int j = 0; j < QK4_NL/2; ++j) {
@@ -11943,7 +11932,6 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
         sumf += d * (sumi1 + sumi2);
     }
     *s = sumf;
-#endif
 }
 
 void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
```
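Two details worth noting in the hunks above. The AVX2, AVX, and LoongArch loops previously read fixed indices `x[0]`/`x[1]` and advanced the pointers themselves (`y += 2; x += 2;`); the rewrite switches them to `ib`-relative indexing so that, when the vector loop exits, `ib` correctly marks where the scalar code must resume. And the scalar loop's guard changes from `#else` to `#endif`, turning the former fallback implementation into a tail shared by every path: with nb = 3, for instance, the vector loop consumes blocks 0 and 1 and the scalar tail finishes block 2.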

tests/test-backend-ops.cpp (27 additions, 0 deletions)
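The test changes add the scaffolding used to find this class of bug: a commented-out TODO for per-row quantization in init_tensor_uniform, and an `#if 1`/`#else` switch in test_backend whose disabled branch generates 1000 random MUL_MAT shapes where k is always a whole multiple of the type's block size. The long-standing fixed shapes all use k = 256 (eight iq4_nl blocks per row, since QK4_NL is 32), so an odd block count was never exercised before.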
```diff
@@ -79,8 +79,16 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
                 im = nullptr;
             }
         }
+
         ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
         GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
+        // TODO: other cases
+        //#pragma omp parallel for
+        //for (int i = 0; i < tensor->ne[1]; i++) {
+        //    ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
+        //        i * tensor->ne[0], 1, tensor->ne[0], im);
+        //}
+
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
```
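The commented-out block appears to sketch a follow-up (quantizing each of the tensor's rows as its own chunk, optionally in parallel with OpenMP), left disabled per the `TODO: other cases`; the active path still quantizes the whole tensor with a single ggml_quantize_chunk call.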
```diff
@@ -2220,6 +2228,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
+#if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
@@ -2239,6 +2248,24 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
         }
     }
+#else
+    // m = a rows
+    // n = b rows
+    // k = cols
+    std::uniform_int_distribution<> dist_m(1, 128);
+    std::uniform_int_distribution<> dist_n(16, 128);
+    std::uniform_int_distribution<> dist_k(1, 16);
+    for (int i = 0; i < 1000; i++) {
+        for (ggml_type type_a : all_types) {
+            for (ggml_type type_b : {GGML_TYPE_F32}) {
+                int m = dist_m(rng);
+                int n = dist_n(rng);
+                int k = dist_k(rng) * ggml_blck_size(type_a);
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, m, n, k, { 1, 1}, {1, 1}));
+            }
+        }
+    }
+#endif
 
     for (ggml_type type_a : other_types) {
         for (ggml_type type_b : {GGML_TYPE_F32}) {
```