Skip to content

Commit 5d64dd6

Browse files
q5_1 impl
1 parent 1c34635 commit 5d64dd6

File tree

1 file changed

+39
-29
lines changed

1 file changed

+39
-29
lines changed

ggml-cuda.cu

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
124124
#define QR5_1 2
125125
#define QI5_1 (QK5_1 / (4 * QR5_1))
126126
// q5_1 quantization block: QK5_1 values stored as 4-bit nibbles in qs plus a
// 5th bit per value in qh, dequantized as q*d + m.
// NOTE(review): reconstructed from the diff hunk shown on this page; the
// QR5_1/QI5_1 defines directly above are outside this span.
typedef struct {
    half2   dm;              // dm.x = delta (scale), dm.y = min — packed so both halves travel together
    uint8_t qh[4];           // 5-th bit of quants (32 bits, one per quant)
    uint8_t qs[QK5_1 / 2];   // nibbles / quants (two 4-bit values per byte)
} block_q5_1;
@@ -447,8 +446,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
447446
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
448447
const block_q5_1 * x = (const block_q5_1 *) vx;
449448

450-
const dfloat d = x[ib].d;
451-
const dfloat m = x[ib].m;
449+
const dfloat d = x[ib].dm.x;
450+
const dfloat m = x[ib].dm.y;
452451

453452
uint32_t qh;
454453
memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -1519,42 +1518,53 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
15191518
y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)], x_dm[index_bx].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
15201519
}
15211520

1522-
// Integer-SIMD dot product of one thread's slice of a q5_1 block against the
// matching q8_1 quants (post-commit form of this hunk: the scale/min pair now
// arrives as a single half2 instead of two separate half loads).
//
// qs:       8 packed 4-bit quants (two nibbles per byte)
// qh:       5th-bit word, pre-shifted by the caller so bits 0..3 belong to the
//           low nibbles and bits 16..19 to the high nibbles
// ui0, ui1: the corresponding q8_1 quants (4 packed int8 each)
// dm5:      q5_1 block's {delta, min}; ds8: q8_1 block's ds pair
//           (presumably {delta, sum-related scale} — confirm against block_q8_1)
static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
    const int & qs, const int & qh, const int & ui0, const int & ui1, const half2 & dm5, const half2 & ds8) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    // Low nibbles: splice each value's 5th bit into bit 4 of its byte lane.
    int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
    vi0 |= (qh <<  4) & 0x00000010; // qh bit 0 ->  4
    vi0 |= (qh << 11) & 0x00001000; // qh bit 1 -> 12
    vi0 |= (qh << 18) & 0x00100000; // qh bit 2 -> 20
    vi0 |= (qh << 25) & 0x10000000; // qh bit 3 -> 28
    int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values

    // High nibbles: their 5th bits live at qh bits 16..19.
    int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
    vi1 |= (qh >> 12) & 0x00000010; // qh bit 16 ->  4
    vi1 |= (qh >>  5) & 0x00001000; // qh bit 17 -> 12
    vi1 |= (qh <<  2) & 0x00100000; // qh bit 18 -> 20
    vi1 |= (qh <<  9) & 0x10000000; // qh bit 19 -> 28
    sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values

#ifdef GGML_CUDA_DMMV_F16
    // Half-precision path: one __hmul2 yields d5*d8 and m5*s8 together.
    const half2 tmp = __hmul2(dm5, ds8);
    const float d5d8 = __half2float(tmp.x);
    const float m5s8 = __half2float(tmp.y);
#else
    const float d5d8 = __half2float(dm5.x) * __half2float(ds8.x);
    const float m5s8 = __half2float(dm5.y) * __half2float(ds8.y);
#endif // GGML_CUDA_DMMV_F16

    return sumi*d5d8 + m5s8/QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
#else
    return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
15571554

1555+
// Thin dispatch wrapper (introduced by this commit): load this thread's slice
// of a q5_1 block and the matching q8_1 quants, then defer the arithmetic to
// vec_dot_q5_1_q8_1_impl.
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {

    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;

    const int qs  = get_int_from_uint8_aligned(bq5_1->qs, iqs);
    // Shift the 32-bit 5th-bit word so this thread's bits land at 0..3 / 16..19,
    // as the impl expects.
    const int qh  = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * iqs);
    const int ui0 = get_int_from_int8_aligned(bq8_1->qs, iqs);
    const int ui1 = get_int_from_int8_aligned(bq8_1->qs, iqs + QI5_1);

    return vec_dot_q5_1_q8_1_impl(qs, qh, ui0, ui1, bq5_1->dm, bq8_1->ds);
}
1567+
15581568
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
15591569
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
15601570

0 commit comments

Comments
 (0)