Skip to content

Commit 393eae0

Browse files
small change
1 parent 859f0b6 commit 393eae0

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

ggml-cuda.cu

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,7 +1405,8 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
14051405

14061406
const int bq8_offset = 4 * (iqs/8);
14071407

1408-
float sumf = 0;
1408+
float sumf_d = 0;
1409+
float sumf_m = 0;
14091410

14101411
const float d = bq2_K->d;
14111412
const float dmin = bq2_K->dmin;
@@ -1414,19 +1415,19 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
14141415

14151416
for (int i = 0; i < 4; ++i) {
14161417
const int sc = bq2_K->scales[iqs - iqs%8 + (iqs%8) / 4 + 2*i];
1417-
const float dl = d * (sc & 0xF);
1418-
const float ml = dmin * (sc >> 4);
14191418

14201419
const int vii = (vi >> (2*i)) & 0x03030303;
14211420

14221421
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
14231422
const float d8 = bq8i->d;
14241423
const int qs8 = *((int*) &bq8i->qs[4*(iqs%8)]);
14251424

1426-
sumf += d8*(dl*__dp4a(vii, qs8, 0) - ml*__dp4a(0x01010101, qs8, 0));
1425+
sumf_d += d8 * __dp4a(vii, qs8, 0) * (sc & 0xF);
1426+
sumf_m += d8 * __dp4a(0x01010101, qs8, 0) * (sc >> 4);
14271427
}
14281428

1429-
return sumf;
1429+
1430+
return d*sumf_d - dmin*sumf_m;
14301431
// #else
14311432
// return 0.0f; // only to satisfy the compiler
14321433
// #endif // __CUDA_ARCH__ >= 600

0 commit comments

Comments
 (0)