Skip to content

Commit be1542e

Browse files
committed
metal : use dequantize_q templates
1 parent 9d00bc2 commit be1542e

File tree

3 files changed

+95
-175
lines changed

3 files changed

+95
-175
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ typedef struct {
8484
} ggml_metal_kargs_repeat;
8585

8686
typedef struct {
87-
int64_t ne;
8887
int64_t ne00;
8988
int64_t ne01;
9089
int64_t ne02;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -408,10 +408,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
408408
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
409409
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
410410
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
411+
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
411412
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
413+
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
412414
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
415+
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
413416
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
417+
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
414418
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
419+
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
415420
GGML_METAL_KERNEL_TYPE_CONCAT,
416421
GGML_METAL_KERNEL_TYPE_SQR,
417422
GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1018,10 +1023,15 @@ @implementation GGMLMetalClass
10181023
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
10191024
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
10201025
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
1026+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
10211027
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
1028+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
10221029
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
1030+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
10231031
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
1032+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
10241033
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
1034+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
10251035
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
10261036
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
10271037
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
@@ -1302,7 +1312,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13021312
case GGML_TYPE_Q5_0:
13031313
case GGML_TYPE_Q5_1:
13041314
case GGML_TYPE_Q8_0:
1305-
return (op->type == GGML_TYPE_F32);
1315+
switch (op->type) {
1316+
case GGML_TYPE_F32:
1317+
case GGML_TYPE_F16:
1318+
return true;
1319+
default:
1320+
return false;
1321+
}
13061322
default:
13071323
return false;
13081324
};
@@ -1631,10 +1647,7 @@ static void ggml_metal_encode_node(
16311647

16321648
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
16331649

1634-
const int64_t ne = ggml_nelements(src0);
1635-
16361650
ggml_metal_kargs_cpy args = {
1637-
/*.ne =*/ ne,
16381651
/*.ne00 =*/ ne00,
16391652
/*.ne01 =*/ ne01,
16401653
/*.ne02 =*/ ne02,
@@ -3918,7 +3931,6 @@ static void ggml_metal_encode_node(
39183931
case GGML_OP_CPY:
39193932
case GGML_OP_CONT:
39203933
{
3921-
const int64_t ne = ggml_nelements(src0);
39223934
id<MTLComputePipelineState> pipeline = nil;
39233935

39243936
switch (src0t) {
@@ -3956,29 +3968,49 @@ static void ggml_metal_encode_node(
39563968
};
39573969
} break;
39583970
case GGML_TYPE_Q4_0:
3971+
{
3972+
switch (dstt) {
3973+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
3974+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
3975+
default: GGML_ABORT("not implemented");
3976+
};
3977+
} break;
39593978
case GGML_TYPE_Q4_1:
3979+
{
3980+
switch (dstt) {
3981+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
3982+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
3983+
default: GGML_ABORT("not implemented");
3984+
};
3985+
} break;
39603986
case GGML_TYPE_Q5_0:
3987+
{
3988+
switch (dstt) {
3989+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
3990+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
3991+
default: GGML_ABORT("not implemented");
3992+
};
3993+
} break;
39613994
case GGML_TYPE_Q5_1:
3995+
{
3996+
switch (dstt) {
3997+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
3998+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
3999+
default: GGML_ABORT("not implemented");
4000+
};
4001+
} break;
39624002
case GGML_TYPE_Q8_0:
39634003
{
3964-
if (dstt == GGML_TYPE_F32) {
3965-
switch (src0t) {
3966-
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
3967-
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
3968-
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
3969-
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
3970-
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
3971-
default: GGML_ABORT("not implemented");
3972-
}
3973-
} else {
3974-
GGML_ABORT("not implemented");
3975-
}
4004+
switch (dstt) {
4005+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
4006+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
4007+
default: GGML_ABORT("not implemented");
4008+
};
39764009
} break;
39774010
default: GGML_ABORT("not implemented");
39784011
}
39794012

39804013
ggml_metal_kargs_cpy args = {
3981-
/*.ne =*/ ne,
39824014
/*.ne00 =*/ ne00,
39834015
/*.ne01 =*/ ne01,
39844016
/*.ne02 =*/ ne02,
@@ -4002,19 +4034,9 @@ static void ggml_metal_encode_node(
40024034
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
40034035
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
40044036

4005-
int nth;
4037+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4038+
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
40064039

4007-
if ( src0t == GGML_TYPE_Q4_0
4008-
|| src0t == GGML_TYPE_Q4_1
4009-
|| src0t == GGML_TYPE_Q5_0
4010-
|| src0t == GGML_TYPE_Q5_1
4011-
|| src0t == GGML_TYPE_Q8_0) {
4012-
GGML_ASSERT(dstt == GGML_TYPE_F32);
4013-
nth = MIN(1024, ne);
4014-
} else {
4015-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4016-
nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4017-
}
40184040
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
40194041

40204042
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 43 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
43414341
}
43424342
}
43434343

4344+
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
4345+
kernel void kernel_cpy_q_f32(
4346+
constant ggml_metal_kargs_cpy & args,
4347+
device const char * src0,
4348+
device char * dst,
4349+
uint3 tgpig[[threadgroup_position_in_grid]],
4350+
ushort3 tpitg[[thread_position_in_threadgroup]],
4351+
ushort3 ntg[[threads_per_threadgroup]]) {
4352+
const int i03 = tgpig[2];
4353+
const int i02 = tgpig[1];
4354+
const int i01 = tgpig[0];
4355+
4356+
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
4357+
4358+
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
4359+
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
4360+
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
4361+
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
4362+
4363+
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
4364+
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4365+
4366+
for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
4367+
T4x4 temp;
4368+
dequantize_func(src_data + i00/nl, i00%nl, temp);
4369+
dst_data[i00] = temp;
4370+
}
4371+
}
4372+
4373+
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
4374+
4375+
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
4376+
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
4377+
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
4378+
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
4379+
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
4380+
4381+
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
4382+
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
4383+
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
4384+
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
4385+
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
4386+
43444387
kernel void kernel_concat(
43454388
constant ggml_metal_kargs_concat & args,
43464389
device const char * src0,
@@ -4372,150 +4415,6 @@ kernel void kernel_concat(
43724415
}
43734416
}
43744417

4375-
template<typename block_q, short qqk, void (*dequantize_func)(device const block_q *, device float *)>
4376-
kernel void kernel_cpy_q_f32(
4377-
constant ggml_metal_kargs_cpy & args,
4378-
device const char * cx [[ buffer(1) ]],
4379-
device char * cdst [[ buffer(2) ]],
4380-
uint tid [[ thread_position_in_grid ]]
4381-
)
4382-
{
4383-
// Compute the global index multiplied by QK, matching:
4384-
// i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4385-
const int i = int(tid) * qqk;
4386-
4387-
// Bounds check
4388-
if (i >= args.ne) {
4389-
return;
4390-
}
4391-
4392-
const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
4393-
const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
4394-
const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
4395-
const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
4396-
const int x_offset = (i00/qqk)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
4397-
4398-
const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
4399-
const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
4400-
const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
4401-
const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
4402-
const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
4403-
4404-
device const block_q * src_block = (device const block_q *)(cx + x_offset);
4405-
device float * dst = (device float *)(cdst + dst_offset);
4406-
4407-
dequantize_func(src_block, dst);
4408-
}
4409-
4410-
void dequant_q4_0_f(device const block_q4_0 * src_block, device float * dst) {
4411-
float d = float(src_block->d);
4412-
const float shift = 8.0f;
4413-
4414-
// Unpack 2 x 4-bit values per byte.
4415-
#pragma unroll(16)
4416-
for (int j = 0; j < QK4_0/2; j++) {
4417-
uint8_t q = src_block->qs[j];
4418-
uint8_t q0 = q & 0x0F;
4419-
uint8_t q1 = (q >> 4) & 0x0F;
4420-
dst[j] = (float(q0) - shift) * d;
4421-
dst[j + QK4_0/2] = (float(q1) - shift) * d;
4422-
}
4423-
}
4424-
4425-
void dequant_q4_1_f(device const block_q4_1 * src_block, device float * dst) {
4426-
float d = float(src_block->d);
4427-
float vmin = float(src_block->m);
4428-
4429-
#pragma unroll(16)
4430-
for (int j = 0; j < QK4_1/2; j++) {
4431-
uint8_t q = src_block->qs[j];
4432-
uint8_t q0 = q & 0x0F;
4433-
uint8_t q1 = (q >> 4) & 0x0F;
4434-
dst[j] = vmin + d * float(q0);
4435-
dst[j + QK4_1/2] = vmin + d * float(q1);
4436-
}
4437-
}
4438-
4439-
void dequant_q5_0_f(device const block_q5_0 * src_block, device float * dst) {
4440-
float d = float(src_block->d);
4441-
const float shift = 16.f;
4442-
4443-
// Combine the four qh bytes into a 32-bit value.
4444-
uint32_t qhVal = 0
4445-
| ((uint32_t) src_block->qh[0] << 0)
4446-
| ((uint32_t) src_block->qh[1] << 8)
4447-
| ((uint32_t) src_block->qh[2] << 16)
4448-
| ((uint32_t) src_block->qh[3] << 24);
4449-
4450-
// First half
4451-
#pragma unroll(16)
4452-
for (int j = 0; j < QK5_0/2; j++) {
4453-
uint8_t q = src_block->qs[j];
4454-
uint8_t lowNib = q & 0x0F;
4455-
uint8_t highBit = (qhVal >> j) & 0x1;
4456-
uint8_t qVal = (highBit << 4) | lowNib;
4457-
dst[j] = (float(qVal) - shift) * d;
4458-
}
4459-
// Second half
4460-
#pragma unroll(16)
4461-
for (int j = QK5_0/2; j < QK5_0; j++) {
4462-
int k = j - QK5_0/2;
4463-
uint8_t q = src_block->qs[k];
4464-
uint8_t hiNib = (q >> 4) & 0x0F;
4465-
uint8_t highBit = (qhVal >> j) & 0x1;
4466-
uint8_t qVal = (highBit << 4) | hiNib;
4467-
dst[j] = (float(qVal) - shift) * d;
4468-
}
4469-
}
4470-
4471-
void dequant_q5_1_f(device const block_q5_1 * src_block, device float * dst) {
4472-
float d = float(src_block->d);
4473-
float vmin = float(src_block->m);
4474-
4475-
uint32_t qhVal = 0
4476-
| ((uint32_t) src_block->qh[0] << 0)
4477-
| ((uint32_t) src_block->qh[1] << 8)
4478-
| ((uint32_t) src_block->qh[2] << 16)
4479-
| ((uint32_t) src_block->qh[3] << 24);
4480-
4481-
// First half
4482-
#pragma unroll(16)
4483-
for (int j = 0; j < QK5_1/2; j++) {
4484-
uint8_t q = src_block->qs[j];
4485-
uint8_t lowNib = q & 0x0F;
4486-
uint8_t highBit = (qhVal >> j) & 0x1;
4487-
uint8_t qVal = (highBit << 4) | lowNib;
4488-
dst[j] = vmin + d * float(qVal);
4489-
}
4490-
// Second half
4491-
#pragma unroll(16)
4492-
for (int j = QK5_1/2; j < QK5_1; j++) {
4493-
int k = j - QK5_1/2;
4494-
uint8_t q = src_block->qs[k];
4495-
uint8_t hiNib = (q >> 4) & 0x0F;
4496-
uint8_t highBit = (qhVal >> j) & 0x1;
4497-
uint8_t qVal = (highBit << 4) | hiNib;
4498-
dst[j] = vmin + d * float(qVal);
4499-
}
4500-
}
4501-
4502-
void dequant_q8_0_f(device const block_q8_0 * src_block, device float * dst) {
4503-
const float d = (float)src_block->d;
4504-
4505-
#pragma unroll(32)
4506-
for (int j = 0; j < QK8_0; j++) {
4507-
dst[j] = src_block->qs[j] * d;
4508-
}
4509-
}
4510-
4511-
typedef decltype(kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>) cpy_q_t;
4512-
4513-
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>;
4514-
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_1, QK4_1, dequant_q4_1_f>;
4515-
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_0, QK5_0, dequant_q5_0_f>;
4516-
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_1, QK5_1, dequant_q5_1_f>;
4517-
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q8_0, QK8_0, dequant_q8_0_f>;
4518-
45194418
template<typename args_t>
45204419
void kernel_mul_mv_q2_K_f32_impl(
45214420
args_t args,

0 commit comments

Comments
 (0)