
Commit 253eab8

ggml : poc for normalizing weights for better quantization (metal)
1 parent b532a69 commit 253eab8

5 files changed: +295 -139 lines changed

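Reading the new block layouts, the idea behind the "normalizing weights" POC seems to be: if the weights of each row are rescaled into a fixed range (roughly [-1, 1]) before quantization, the per-block delta and min also fall into fixed, known ranges, so they can be stored as small integer codes instead of fp16 values. A minimal C sketch of the resulting q4_0/q4_1 layouts and what they save (the structs mirror this diff; the per-row rescaling itself is assumed and is not shown in the hunks below):

```c
#include <stdint.h>

#define QK4_0 32
#define QK4_1 32

typedef struct {
    int8_t  d;              // delta code (was ggml_fp16_t)
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;               // 17 bytes per 32 weights = 4.25 bpw (was 18 bytes = 4.5 bpw)

typedef struct {
    uint16_t dm;            // 8-bit delta code + 8-bit min code (was 2 x ggml_fp16_t)
    uint8_t  qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;               // 18 bytes per 32 weights = 4.5 bpw (was 20 bytes = 5.0 bpw)

_Static_assert(sizeof(block_q4_0) == 1 + QK4_0/2, "q4_0 block is 17 bytes");
_Static_assert(sizeof(block_q4_1) == 2 + QK4_1/2, "q4_1 block is 18 bytes");
```

block_q5_1 shrinks the same way, packing a 4-bit delta code and a 4-bit min code into a single byte (24 -> 21 bytes per block).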

ggml-cuda.cu

Lines changed: 38 additions & 15 deletions

@@ -204,23 +204,31 @@ typedef void (*ggml_cuda_op_t)(
 // QR = QK / number of values before dequantization
 // QI = number of 32 bit integers before dequantization
 
+#define Q4_0DM (1.0f/8.0f)
+#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
+
 #define QK4_0 32
 #define QR4_0 2
 #define QI4_0 (QK4_0 / (4 * QR4_0))
 typedef struct {
-    half d;                 // delta
+    int8_t d;               // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define Q4_1DM (2.0f/15.0f)
+#define Q4_1MM (2.0f )
+#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
+#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
 
 #define QK4_1 32
 #define QR4_1 2
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 typedef struct {
-    half2 dm;               // dm.x = delta, dm.y = min
-    uint8_t qs[QK4_1 / 2];  // nibbles / quants
+    uint16_t dm;            // 8-bit delta + 8-bit min (can be adjusted easily)
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == sizeof(uint16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
 #define QR5_0 2
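
A quick sanity check of the decode macros above, in plain C: the int8 code of q4_0 maps onto a delta in [-1/8, 1/8], and the two bytes of the q4_1 dm field map onto a delta in [0, 2/15] and a min in [-1, 1]. Those ranges only cover all possible blocks if the weights fed to the quantizer are already normalized to roughly [-1, 1], which appears to be the assumption behind this POC. The stored codes below are hypothetical examples:

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>

// decode macros copied from the patch
#define Q4_0DM (1.0f/8.0f)
#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)

#define Q4_1DM (2.0f/15.0f)
#define Q4_1MM (2.0f )
#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)

int main(void) {
    // q4_0: the full int8 range covers deltas up to 1/8; since the 4-bit quants
    // span [-8, 7], a reconstructed weight d * q stays within roughly [-1, 1]
    const int8_t d8 = 127;                  // hypothetical stored code
    assert(fabsf(Q4_0D(d8) - 1.0f/8.0f) < 1e-6f);

    // q4_1: low byte -> delta in [0, 2/15], high byte -> min in [-1, 1]
    const uint16_t dm = (255u << 8) | 255u; // hypothetical stored codes
    printf("q4_1: d = %f, m = %f\n", Q4_1D(dm), Q4_1M(dm)); // prints 0.133333, 1.000000

    return 0;
}
```
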
@@ -232,15 +240,20 @@ typedef struct {
 } block_q5_0;
 static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
 
+#define Q5_1DM (2.0f/31.0f)
+#define Q5_1MM (2.0f )
+#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f)
+#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f)
+
 #define QK5_1 32
 #define QR5_1 2
 #define QI5_1 (QK5_1 / (4 * QR5_1))
 typedef struct {
-    half2 dm;               // dm.x = delta, dm.y = min
+    uint8_t dm;             // 4-bit delta + 4-bit min
     uint8_t qh[4];          // 5-th bit of quants
     uint8_t qs[QK5_1 / 2];  // nibbles / quants
 } block_q5_1;
-static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
 
 #define QK8_0 32
 #define QR8_0 1
@@ -506,7 +519,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
 static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
-    const dfloat d = x[ib].d;
+    const dfloat d = Q4_0D(x[ib].d);
 
     const int vui = x[ib].qs[iqs];
 
@@ -525,8 +538,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
 static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
-    const dfloat d = __low2half(x[ib].dm);
-    const dfloat m = __high2half(x[ib].dm);
+    const dfloat d = Q4_1D(x[ib].dm);
+    const dfloat m = Q4_1M(x[ib].dm);
 
     const int vui = x[ib].qs[iqs];
 
@@ -568,8 +581,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
 static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
-    const dfloat d = __low2half(x[ib].dm);
-    const dfloat m = __high2half(x[ib].dm);
+    const dfloat d = Q5_1D(x[ib].dm);
+    const dfloat m = Q5_1M(x[ib].dm);
 
     uint32_t qh;
     memcpy(&qh, x[ib].qh, sizeof(qh));
@@ -2041,7 +2054,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
         u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0);
     }
 
-    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, Q4_0D(bq4_0->d), bq8_1->ds);
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2135,7 +2148,12 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
         u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
     }
 
-    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+    const float d = Q4_1D(bq4_1->dm);
+    const float m = Q4_1M(bq4_1->dm);
+
+    const float2 dm = {d, m};
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, dm, bq8_1->ds);
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2341,7 +2359,12 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
         u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
     }
 
-    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+    const float d = Q5_1D(bq5_1->dm);
+    const float m = Q5_1M(bq5_1->dm);
+
+    const float2 dm = {d, m};
+
+    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, dm, bq8_1->ds);
 }
 
 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
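
The encode side is not part of these hunks (the quantizer presumably lives in the ggml.c changes not shown on this page), but the macros pin down what it has to store. A hypothetical C sketch of the inverse mappings, assuming the block statistics are already in range (|d| <= 1/8 for q4_0; d in [0, 2/15] and m in [-1, 1] for q4_1), with no clamping:

```c
#include <math.h>
#include <stdint.h>

#define Q4_0DM (1.0f/8.0f)
#define Q4_1DM (2.0f/15.0f)

// inverse of Q4_0D(): Q4_0D(encode_q4_0_d(d)) ~= d for |d| <= 1/8
static inline int8_t encode_q4_0_d(float d) {
    return (int8_t) roundf(d * 127.0f / Q4_0DM);
}

// inverse of Q4_1D()/Q4_1M(): delta code in the low byte, min code in the
// high byte, matching the layout the uint16_t dm field is decoded with
static inline uint16_t encode_q4_1_dm(float d, float m) {
    const uint16_t dq = (uint16_t) roundf(d * 255.0f / Q4_1DM);        // 0..255
    const uint16_t mq = (uint16_t) roundf((m + 1.0f) * 255.0f / 2.0f); // 0..255
    return (uint16_t) (dq | (mq << 8));
}
```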

ggml-metal.m

Lines changed: 5 additions & 2 deletions

@@ -697,6 +697,9 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_MUL:
                 {
+                    GGML_ASSERT(ne00 % 4 == 0);
+                    const int64_t nb = ne00/4;
+
                     if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -706,9 +709,9 @@
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
 
-                    const int64_t n = ggml_nelements(dst);
+                    const int64_t n = ggml_nelements(dst)/4;
 
                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
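
These host-side changes pair with the float4 rewrite of kernel_mul / kernel_mul_row in ggml-metal.metal below: each thread now multiplies one float4, so the kernel receives nb = ne00/4 instead of ne00 and the grid shrinks to ggml_nelements(dst)/4 threads. The row broadcast stays correct only because ne00 is a multiple of 4 (hence the new GGML_ASSERT): the scalar index i % ne00 then equals 4*((i/4) % nb) + (i % 4). A small C check of that identity, with a hypothetical row length:

```c
#include <assert.h>
#include <stdint.h>

int main(void) {
    const int64_t ne00 = 32;        // hypothetical row length, a multiple of 4
    const int64_t nb   = ne00 / 4;  // what the kernel now receives at buffer index 3

    for (int64_t i = 0; i < 4096; ++i) {
        const int64_t scalar = i % ne00;     // old per-float indexing
        const int64_t vec    = (i / 4) % nb; // new per-float4 indexing
        const int64_t lane   = i % 4;        // component within the float4
        assert(scalar == 4*vec + lane);      // holds because ne00 % 4 == 0
    }
    return 0;
}
```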

ggml-metal.metal

Lines changed: 40 additions & 27 deletions

@@ -4,17 +4,22 @@ using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
 
+#define Q4_0DM (1.0f/8.0f)
+#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)
 #define QK4_0 32
 #define QR4_0 2
 typedef struct {
-    half d;                // delta
+    int8_t d;              // delta
     uint8_t qs[QK4_0 / 2]; // nibbles / quants
 } block_q4_0;
 
+#define Q4_1DM (2.0f/15.0f)
+#define Q4_1MM (2.0f )
+#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f)
+#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f)
 #define QK4_1 32
 typedef struct {
-    half d;                // delta
-    half m;                // min
+    uint16_t dm;
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
@@ -44,22 +49,22 @@ kernel void kernel_add_row(
 }
 
 kernel void kernel_mul(
-        device const float * src0,
-        device const float * src1,
-        device float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
 
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float * src0,
-        device const float * src1,
-        device float * dst,
-        constant int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device float4 * dst,
+        constant int64_t & nb,
        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
@@ -314,14 +319,18 @@ kernel void kernel_rms_norm(
 // we assume that the yl's have been multiplied with the appropriate scale factor
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
-    float d = qb_curr->d;
+    float d = Q4_0D(qb_curr->d);
     float2 acc = 0.f;
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    device const uint8_t * qs = ((device const uint8_t *)qb_curr->qs + il);
+    uint16_t qs16;
    for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
-                + yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
-                + yl[i + 9] * (qs[i / 2] & 0xF000);
+        qs16  = qs[i+1];
+        qs16 <<= 8;
+        qs16 |= qs[i];
+        acc[0] += yl[i + 0] * (qs16 & 0x000F)
+                + yl[i + 1] * (qs16 & 0x0F00);
+        acc[1] += yl[i + 8] * (qs16 & 0x00F0)
+                + yl[i + 9] * (qs16 & 0xF000);
     }
     return d * (sumy * -8.f + acc[0] + acc[1]);
 }
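
The byte-by-byte reads above look like an alignment fix: with d shrunk from a 2-byte half to a 1-byte int8_t, block_q4_0 is only 1-byte aligned and qs starts at offset 1, so the old device const uint16_t * loads can no longer be used; each 16-bit group is instead assembled from two byte loads, and the existing nibble masks (0x000F, 0x0F00, 0x00F0, 0xF000) apply unchanged. block_q4_1 keeps its uint16_t loads because the packed dm field is still 2 bytes; only the pointer offset drops from +2 to +1. The same load pattern in plain C, for reference:

```c
#include <stdint.h>

// read a little-endian 16-bit value from a possibly unaligned byte pointer,
// mirroring the qs16 assembly used in the Metal kernels of this diff
static inline uint16_t load_u16_le(const uint8_t * p) {
    uint16_t v = p[1];
    v <<= 8;
    v |= p[0];
    return v;
}
```
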
@@ -331,9 +340,9 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 // we assume that the yl's have been multiplied with the appropriate scale factor
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
-    float d = qb_curr->d;
-    float m = qb_curr->m;
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    float d = Q4_1D(qb_curr->dm);
+    float m = Q4_1M(qb_curr->dm);
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
     float2 acc = 0.f;
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
@@ -1686,23 +1695,27 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
+    device const uint8_t * qs = ((device const uint8_t *)xb->qs);
+    const half d = il ? (Q4_0D(xb->d) / 16.h) : Q4_0D(xb->d);
     const half m = il ? ( -8.h * 16.h) : -8.h;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
+    uint16_t qs16;
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+        qs16  = qs[2*i+1];
+        qs16 <<= 8;
+        qs16 |= qs[2*i];
+        reg[i/2][2*(i%2)]   = (((qs16 & mask0)     ) + m) * d;
+        reg[i/2][2*(i%2)+1] = (((qs16 & mask1) >> 8) + m) * d;
     }
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
-    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const half d = il ? (Q4_1D(xb->dm) / 16.h) : Q4_1D(xb->dm);
+    const half m = Q4_1M(xb->dm);
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = il ? 0xF000 : 0x0F00;
 
0 commit comments
