Skip to content

Commit 274b0d7

Browse files
faster q5_0
1 parent 89d2b3e commit 274b0d7

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

ggml-cuda.cu

Lines changed: 14 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1446,26 +1446,24 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
1446 1446   }
1447 1447
1448 1448   static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
1449      - const int & qs, const int & qh0, const int & qh1, const int & ui0, const int & ui1, const half & d5, const half2 & ds8) {
     1449 + const int & qs, const int & qh, const int & ui0, const int & ui1, const half & d5, const half2 & ds8) {
1450 1450
1451 1451   #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
1452      - int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
1453      - vi0 |= (qh0 << 4) & 0x00000010; // 0 -> 4
1454      - vi0 |= (qh0 << 11) & 0x00001000; // 1 -> 12
1455      - vi0 |= (qh0 << 18) & 0x00100000; // 2 -> 20
1456      - vi0 |= (qh0 << 25) & 0x10000000; // 3 -> 28
1457      - vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values
     1452 + int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
     1453 + vi0 |= (qh << 4) & 0x00000010; // 0 -> 4
     1454 + vi0 |= (qh << 11) & 0x00001000; // 1 -> 12
     1455 + vi0 |= (qh << 18) & 0x00100000; // 2 -> 20
     1456 + vi0 |= (qh << 25) & 0x10000000; // 3 -> 28
1458 1457   int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
1459 1458
1460      - int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
1461      - vi1 |= (qh1 << 4) & 0x00000010; // 0 -> 4
1462      - vi1 |= (qh1 << 11) & 0x00001000; // 1 -> 12
1463      - vi1 |= (qh1 << 18) & 0x00100000; // 2 -> 20
1464      - vi1 |= (qh1 << 25) & 0x10000000; // 3 -> 28
1465      - vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values
     1459 + int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
     1460 + vi1 |= (qh >> 12) & 0x00000010; // 16 -> 4
     1461 + vi1 |= (qh >> 5) & 0x00001000; // 17 -> 12
     1462 + vi1 |= (qh << 2) & 0x00100000; // 18 -> 20
     1463 + vi1 |= (qh << 9) & 0x10000000; // 19 -> 28
1466 1464   sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
1467 1465
1468      - return sumi*__half2float(d5)*__half2float(ds8.x);
     1466 + return __half2float(d5) * (sumi*__half2float(ds8.x) - (16/QI5_0) * __half2float(ds8.y));
1469 1467   #else
1470 1468   return 0.0f; // only to satisfy the compiler
1471 1469   #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1477,12 +1475,11 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
1477 1475   const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
1478 1476
1479 1477   const int qs = get_int_from_uint8(bq5_0->qs, iqs);
1480      - const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2);
1481      - const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2);
     1478 + const int qh = get_int_from_uint8(bq5_0->qh, 0) >> (4 * iqs);
1482 1479   const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
1483 1480   const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_0);
1484 1481
1485      - return vec_dot_q5_0_q8_1_impl(qs, qh0, qh1, ui0, ui1, bq5_0->d, bq8_1->ds);
     1482 + return vec_dot_q5_0_q8_1_impl(qs, qh, ui0, ui1, bq5_0->d, bq8_1->ds);
1486 1483   }
1487 1484
1488 1485   static __device__ __forceinline__ float vec_dot_q5_1_q8_1(

0 commit comments

Comments
 (0)