Skip to content

Commit a43f2fb

Browse files
committed
metal: Copy kernels for quant to F32 conversions (#10976).
1 parent 51f311e commit a43f2fb

File tree

3 files changed

+189
-5
lines changed

3 files changed

+189
-5
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ typedef struct {
8484
} ggml_metal_kargs_repeat;
8585

8686
typedef struct {
87+
int64_t ne;
8788
int64_t ne00;
8889
int64_t ne01;
8990
int64_t ne02;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
407407
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
408408
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
409409
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
410+
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
411+
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
412+
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
413+
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
414+
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
410415
GGML_METAL_KERNEL_TYPE_CONCAT,
411416
GGML_METAL_KERNEL_TYPE_SQR,
412417
GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1017,11 @@ @implementation GGMLMetalClass
10121017
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
10131018
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
10141019
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
1020+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
1021+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
1022+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
1023+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
1024+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
10151025
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
10161026
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
10171027
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
@@ -1287,6 +1297,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
12871297
default:
12881298
return false;
12891299
}
1300+
case GGML_TYPE_Q4_0:
1301+
case GGML_TYPE_Q4_1:
1302+
case GGML_TYPE_Q5_0:
1303+
case GGML_TYPE_Q5_1:
1304+
case GGML_TYPE_Q8_0:
1305+
return (op->type == GGML_TYPE_F32);
12901306
default:
12911307
return false;
12921308
};
@@ -1615,7 +1631,10 @@ static void ggml_metal_encode_node(
16151631

16161632
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
16171633

1634+
const int64_t ne = ggml_nelements(src0);
1635+
16181636
ggml_metal_kargs_cpy args = {
1637+
/*.ne =*/ ne,
16191638
/*.ne00 =*/ ne00,
16201639
/*.ne01 =*/ ne01,
16211640
/*.ne02 =*/ ne02,
@@ -3899,10 +3918,7 @@ static void ggml_metal_encode_node(
38993918
case GGML_OP_CPY:
39003919
case GGML_OP_CONT:
39013920
{
3902-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
3903-
3904-
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
3905-
3921+
const int64_t ne = ggml_nelements(src0);
39063922
id<MTLComputePipelineState> pipeline = nil;
39073923

39083924
switch (src0t) {
@@ -3936,13 +3952,33 @@ static void ggml_metal_encode_node(
39363952
switch (dstt) {
39373953
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
39383954
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
3939-
default: GGML_ASSERT(false && "not implemented");
3955+
default: GGML_ABORT("not implemented");
39403956
};
39413957
} break;
3958+
case GGML_TYPE_Q4_0:
3959+
case GGML_TYPE_Q4_1:
3960+
case GGML_TYPE_Q5_0:
3961+
case GGML_TYPE_Q5_1:
3962+
case GGML_TYPE_Q8_0:
3963+
{
3964+
if (dstt == GGML_TYPE_F32) {
3965+
switch (src0t) {
3966+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
3967+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
3968+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
3969+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
3970+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
3971+
default: GGML_ABORT("not implemented");
3972+
}
3973+
} else {
3974+
GGML_ABORT("not implemented");
3975+
}
3976+
} break;
39423977
default: GGML_ABORT("not implemented");
39433978
}
39443979

39453980
ggml_metal_kargs_cpy args = {
3981+
/*.ne =*/ ne,
39463982
/*.ne00 =*/ ne00,
39473983
/*.ne01 =*/ ne01,
39483984
/*.ne02 =*/ ne02,
@@ -3966,7 +4002,17 @@ static void ggml_metal_encode_node(
39664002
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
39674003
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
39684004

4005+
int nth;
4006+
4007+
if (src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
4008+
GGML_ASSERT(dstt == GGML_TYPE_F32);
4009+
nth = MIN(1024, ne);
4010+
} else {
4011+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4012+
nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4013+
}
39694014
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4015+
39704016
} break;
39714017
case GGML_OP_SET:
39724018
{

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4372,6 +4372,143 @@ kernel void kernel_concat(
43724372
}
43734373
}
43744374

4375+
template<typename block_q, short qqk, void (*dequantize_func)(device const block_q *, device float *)>
4376+
kernel void kernel_cpy_q_f32(
4377+
constant ggml_metal_kargs_cpy & args,
4378+
device const char * cx [[ buffer(1) ]],
4379+
device char * cdst [[ buffer(2) ]],
4380+
uint tid [[ thread_position_in_grid ]]
4381+
)
4382+
{
4383+
// Compute the global index multiplied by QK, matching:
4384+
// i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4385+
const int i = int(tid) * qqk;
4386+
4387+
// Bounds check
4388+
if (i >= args.ne) {
4389+
return;
4390+
}
4391+
4392+
const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
4393+
const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
4394+
const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
4395+
const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
4396+
const int x_offset = (i00/qqk)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
4397+
4398+
const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
4399+
const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
4400+
const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
4401+
const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
4402+
const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
4403+
4404+
device const block_q * src_block = (device const block_q *)(cx + x_offset);
4405+
device float * dst = (device float *)(cdst + dst_offset);
4406+
4407+
dequantize_func(src_block, dst);
4408+
}
4409+
4410+
void dequant_q4_0_f(device const block_q4_0 * src_block, device float * dst) {
4411+
float d = float(src_block->d);
4412+
const float shift = 8.0f;
4413+
4414+
// Unpack 2 x 4-bit values per byte.
4415+
for (int j = 0; j < QK4_0/2; j++) {
4416+
uint8_t q = src_block->qs[j];
4417+
uint8_t q0 = q & 0x0F;
4418+
uint8_t q1 = (q >> 4) & 0x0F;
4419+
dst[j] = (float(q0) - shift) * d;
4420+
dst[j + QK4_0/2] = (float(q1) - shift) * d;
4421+
}
4422+
}
4423+
4424+
void dequant_q4_1_f(device const block_q4_1 * src_block, device float * dst) {
4425+
float d = float(src_block->d);
4426+
float vmin = float(src_block->m);
4427+
4428+
for (int j = 0; j < QK4_1/2; j++) {
4429+
uint8_t q = src_block->qs[j];
4430+
uint8_t q0 = q & 0x0F;
4431+
uint8_t q1 = (q >> 4) & 0x0F;
4432+
dst[j] = vmin + d * float(q0);
4433+
dst[j + QK4_1/2] = vmin + d * float(q1);
4434+
}
4435+
}
4436+
4437+
void dequant_q5_0_f(device const block_q5_0 * src_block, device float * dst) {
4438+
float d = float(src_block->d);
4439+
const float shift = 16.f;
4440+
4441+
// Combine the four qh bytes into a 32-bit value.
4442+
uint32_t qhVal = 0
4443+
| ((uint32_t) src_block->qh[0] << 0)
4444+
| ((uint32_t) src_block->qh[1] << 8)
4445+
| ((uint32_t) src_block->qh[2] << 16)
4446+
| ((uint32_t) src_block->qh[3] << 24);
4447+
4448+
// First half
4449+
for (int j = 0; j < QK5_0/2; j++) {
4450+
uint8_t q = src_block->qs[j];
4451+
uint8_t lowNib = q & 0x0F;
4452+
uint8_t highBit = (qhVal >> j) & 0x1;
4453+
uint8_t qVal = (highBit << 4) | lowNib;
4454+
dst[j] = (float(qVal) - shift) * d;
4455+
}
4456+
// Second half
4457+
for (int j = QK5_0/2; j < QK5_0; j++) {
4458+
int k = j - QK5_0/2;
4459+
uint8_t q = src_block->qs[k];
4460+
uint8_t hiNib = (q >> 4) & 0x0F;
4461+
uint8_t highBit = (qhVal >> j) & 0x1;
4462+
uint8_t qVal = (highBit << 4) | hiNib;
4463+
dst[j] = (float(qVal) - shift) * d;
4464+
}
4465+
}
4466+
4467+
void dequant_q5_1_f(device const block_q5_1 * src_block, device float * dst) {
4468+
float d = float(src_block->d);
4469+
float vmin = float(src_block->m);
4470+
4471+
uint32_t qhVal = 0
4472+
| ((uint32_t) src_block->qh[0] << 0)
4473+
| ((uint32_t) src_block->qh[1] << 8)
4474+
| ((uint32_t) src_block->qh[2] << 16)
4475+
| ((uint32_t) src_block->qh[3] << 24);
4476+
4477+
// First half
4478+
for (int j = 0; j < QK5_1/2; j++) {
4479+
uint8_t q = src_block->qs[j];
4480+
uint8_t lowNib = q & 0x0F;
4481+
uint8_t highBit = (qhVal >> j) & 0x1;
4482+
uint8_t qVal = (highBit << 4) | lowNib;
4483+
dst[j] = vmin + d * float(qVal);
4484+
}
4485+
// Second half
4486+
for (int j = QK5_1/2; j < QK5_1; j++) {
4487+
int k = j - QK5_1/2;
4488+
uint8_t q = src_block->qs[k];
4489+
uint8_t hiNib = (q >> 4) & 0x0F;
4490+
uint8_t highBit = (qhVal >> j) & 0x1;
4491+
uint8_t qVal = (highBit << 4) | hiNib;
4492+
dst[j] = vmin + d * float(qVal);
4493+
}
4494+
}
4495+
4496+
void dequant_q8_0_f(device const block_q8_0 * src_block, device float * dst) {
4497+
const float d = (float)src_block->d;
4498+
4499+
for (int j = 0; j < QK8_0; j++) {
4500+
dst[j] = src_block->qs[j] * d;
4501+
}
4502+
}
4503+
4504+
typedef decltype(kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>) cpy_q_t;
4505+
4506+
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>;
4507+
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_1, QK4_1, dequant_q4_1_f>;
4508+
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_0, QK5_0, dequant_q5_0_f>;
4509+
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_1, QK5_1, dequant_q5_1_f>;
4510+
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q8_0, QK8_0, dequant_q8_0_f>;
4511+
43754512
template<typename args_t>
43764513
void kernel_mul_mv_q2_K_f32_impl(
43774514
args_t args,

0 commit comments

Comments
 (0)