Skip to content

Commit c642a56

Browse files
committed
metal : use dequantize_q templates
1 parent 9d00bc2 commit c642a56

File tree

3 files changed

+95
-173
lines changed

3 files changed

+95
-173
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ typedef struct {
8484
} ggml_metal_kargs_repeat;
8585

8686
typedef struct {
87-
int64_t ne;
8887
int64_t ne00;
8988
int64_t ne01;
9089
int64_t ne02;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -408,10 +408,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
408408
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
409409
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
410410
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
411+
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
411412
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
413+
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
412414
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
415+
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
413416
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
417+
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
414418
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
419+
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
415420
GGML_METAL_KERNEL_TYPE_CONCAT,
416421
GGML_METAL_KERNEL_TYPE_SQR,
417422
GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1018,10 +1023,15 @@ @implementation GGMLMetalClass
10181023
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
10191024
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
10201025
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
1026+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16, cpy_q4_0_f16, true);
10211027
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
1028+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16, cpy_q4_1_f16, true);
10221029
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
1030+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16, cpy_q5_0_f16, true);
10231031
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
1032+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16, cpy_q5_1_f16, true);
10241033
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
1034+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16, cpy_q8_0_f16, true);
10251035
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
10261036
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
10271037
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
@@ -1302,7 +1312,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13021312
case GGML_TYPE_Q5_0:
13031313
case GGML_TYPE_Q5_1:
13041314
case GGML_TYPE_Q8_0:
1305-
return (op->type == GGML_TYPE_F32);
1315+
switch (op->type) {
1316+
case GGML_TYPE_F32:
1317+
case GGML_TYPE_F16:
1318+
return true;
1319+
default:
1320+
return false;
1321+
}
13061322
default:
13071323
return false;
13081324
};
@@ -1634,7 +1650,6 @@ static void ggml_metal_encode_node(
16341650
const int64_t ne = ggml_nelements(src0);
16351651

16361652
ggml_metal_kargs_cpy args = {
1637-
/*.ne =*/ ne,
16381653
/*.ne00 =*/ ne00,
16391654
/*.ne01 =*/ ne01,
16401655
/*.ne02 =*/ ne02,
@@ -3918,7 +3933,6 @@ static void ggml_metal_encode_node(
39183933
case GGML_OP_CPY:
39193934
case GGML_OP_CONT:
39203935
{
3921-
const int64_t ne = ggml_nelements(src0);
39223936
id<MTLComputePipelineState> pipeline = nil;
39233937

39243938
switch (src0t) {
@@ -3956,29 +3970,49 @@ static void ggml_metal_encode_node(
39563970
};
39573971
} break;
39583972
case GGML_TYPE_Q4_0:
3973+
{
3974+
switch (dstt) {
3975+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
3976+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
3977+
default: GGML_ABORT("not implemented");
3978+
};
3979+
} break;
39593980
case GGML_TYPE_Q4_1:
3981+
{
3982+
switch (dstt) {
3983+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
3984+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
3985+
default: GGML_ABORT("not implemented");
3986+
};
3987+
} break;
39603988
case GGML_TYPE_Q5_0:
3989+
{
3990+
switch (dstt) {
3991+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
3992+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
3993+
default: GGML_ABORT("not implemented");
3994+
};
3995+
} break;
39613996
case GGML_TYPE_Q5_1:
3997+
{
3998+
switch (dstt) {
3999+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
4000+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
4001+
default: GGML_ABORT("not implemented");
4002+
};
4003+
} break;
39624004
case GGML_TYPE_Q8_0:
39634005
{
3964-
if (dstt == GGML_TYPE_F32) {
3965-
switch (src0t) {
3966-
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
3967-
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
3968-
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
3969-
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
3970-
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
3971-
default: GGML_ABORT("not implemented");
3972-
}
3973-
} else {
3974-
GGML_ABORT("not implemented");
3975-
}
4006+
switch (dstt) {
4007+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
4008+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
4009+
default: GGML_ABORT("not implemented");
4010+
};
39764011
} break;
39774012
default: GGML_ABORT("not implemented");
39784013
}
39794014

39804015
ggml_metal_kargs_cpy args = {
3981-
/*.ne =*/ ne,
39824016
/*.ne00 =*/ ne00,
39834017
/*.ne01 =*/ ne01,
39844018
/*.ne02 =*/ ne02,
@@ -4002,19 +4036,9 @@ static void ggml_metal_encode_node(
40024036
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
40034037
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
40044038

4005-
int nth;
4039+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4040+
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
40064041

4007-
if ( src0t == GGML_TYPE_Q4_0
4008-
|| src0t == GGML_TYPE_Q4_1
4009-
|| src0t == GGML_TYPE_Q5_0
4010-
|| src0t == GGML_TYPE_Q5_1
4011-
|| src0t == GGML_TYPE_Q8_0) {
4012-
GGML_ASSERT(dstt == GGML_TYPE_F32);
4013-
nth = MIN(1024, ne);
4014-
} else {
4015-
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4016-
nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4017-
}
40184042
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
40194043

40204044
} break;

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 43 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
43414341
}
43424342
}
43434343

// Copy kernel for quantized sources.
//
// Each threadgroup handles one source row, identified by its grid position
// (i01, i02, i03). The threads of the group stride over that row; every
// iteration calls dequantize_func to fill one thread-local T4x4 and stores it
// into dst.
//
// Template parameters:
//   T4x4            - type of the value produced and stored per iteration
//   block_q         - type the source row is reinterpreted as
//   nl              - iteration i00 reads block i00/nl and passes i00%nl as the
//                     short index argument of dequantize_func
//   dequantize_func - fills a thread-local T4x4 from (block pointer, short index)
//
// Arguments:
//   args - extents (ne00..ne02, ne0..ne2) and byte strides (nb01..nb03, nb0..nb3)
//   src0 - source buffer, addressed in bytes
//   dst  - destination buffer, addressed in bytes
//
// Note: despite the _f32 suffix, the template is also instantiated with
// half4x4 below for the kernel_cpy_*_f16 host names.
4344+
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
4345+
kernel void kernel_cpy_q_f32(
4346+
constant ggml_metal_kargs_cpy & args,
4347+
device const char * src0,
4348+
device char * dst,
4349+
uint3 tgpig[[threadgroup_position_in_grid]],
4350+
ushort3 tpitg[[thread_position_in_threadgroup]],
4351+
ushort3 ntg[[threads_per_threadgroup]]) {
// Source row coordinates come straight from the threadgroup's grid position.
4352+
const int i03 = tgpig[2];
4353+
const int i02 = tgpig[1];
4354+
const int i01 = tgpig[0];
4355+
// Linear element index of the first element of this source row.
4356+
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
4357+
// Split that linear index into destination coordinates using the dst extents.
4358+
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
4359+
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
4360+
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
4361+
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
4362+
// Start of the source row (byte strides nb01..nb03) and the matching dst position (nb0..nb3).
4363+
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
4364+
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4365+
// ne00/16 iterations per row, shared among the ntg.x threads of the group
// (presumably 16 = number of values held by one T4x4 — TODO confirm).
// NOTE(review): dst_data[i00] stores T4x4 values back-to-back, which assumes the
// destination elements of this row are contiguous — confirm against the host-side dispatch.
4366+
for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
4367+
T4x4 temp;
4368+
dequantize_func(src_data + i00/nl, i00%nl, temp);
4369+
dst_data[i00] = temp;
4370+
}
4371+
}
4372+
// Function type used for every explicit instantiation below; it is taken from the
// <float4x4, block_q4_0, 2, dequantize_q4_0> instantiation of kernel_cpy_q_f32.
4373+
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
4374+
// Instantiations with T4x4 = float4x4, exported under the kernel_cpy_*_f32 host names (all use nl = 2).
4375+
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
4376+
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
4377+
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
4378+
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
4379+
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
4380+
// Same template with T4x4 = half4x4, exported under the kernel_cpy_*_f16 host names (all use nl = 2).
4381+
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
4382+
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
4383+
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
4384+
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
4385+
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
4386+
43444387
kernel void kernel_concat(
43454388
constant ggml_metal_kargs_concat & args,
43464389
device const char * src0,
@@ -4372,150 +4415,6 @@ kernel void kernel_concat(
43724415
}
43734416
}
43744417

4375-
template<typename block_q, short qqk, void (*dequantize_func)(device const block_q *, device float *)>
4376-
kernel void kernel_cpy_q_f32(
4377-
constant ggml_metal_kargs_cpy & args,
4378-
device const char * cx [[ buffer(1) ]],
4379-
device char * cdst [[ buffer(2) ]],
4380-
uint tid [[ thread_position_in_grid ]]
4381-
)
4382-
{
4383-
// Compute the global index multiplied by QK, matching:
4384-
// i = (blockDim.x*blockIdx.x + threadIdx.x)*qk
4385-
const int i = int(tid) * qqk;
4386-
4387-
// Bounds check
4388-
if (i >= args.ne) {
4389-
return;
4390-
}
4391-
4392-
const int i03 = i/(args.ne00 * args.ne01 * args.ne02);
4393-
const int i02 = (i - i03*args.ne00*args.ne01*args.ne02 )/ (args.ne00*args.ne01);
4394-
const int i01 = (i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00) / args.ne00;
4395-
const int i00 = i - i03*args.ne00*args.ne01*args.ne02 - i02*args.ne01*args.ne00 - i01*args.ne00;
4396-
const int x_offset = (i00/qqk)*args.nb00 + i01*args.nb01 + i02*args.nb02 + i03 * args.nb03;
4397-
4398-
const int i13 = i/(args.ne0 * args.ne1 * args.ne2);
4399-
const int i12 = (i - i13*args.ne0*args.ne1*args.ne2) / (args.ne0*args.ne1);
4400-
const int i11 = (i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1) / args.ne0;
4401-
const int i10 = i - i13*args.ne0*args.ne1*args.ne2 - i12*args.ne0*args.ne1 - i11*args.ne0;
4402-
const int dst_offset = i10*args.nb0 + i11*args.nb1 + i12*args.nb2 + i13*args.nb3;
4403-
4404-
device const block_q * src_block = (device const block_q *)(cx + x_offset);
4405-
device float * dst = (device float *)(cdst + dst_offset);
4406-
4407-
dequantize_func(src_block, dst);
4408-
}
4409-
4410-
void dequant_q4_0_f(device const block_q4_0 * src_block, device float * dst) {
4411-
float d = float(src_block->d);
4412-
const float shift = 8.0f;
4413-
4414-
// Unpack 2 x 4-bit values per byte.
4415-
#pragma unroll(16)
4416-
for (int j = 0; j < QK4_0/2; j++) {
4417-
uint8_t q = src_block->qs[j];
4418-
uint8_t q0 = q & 0x0F;
4419-
uint8_t q1 = (q >> 4) & 0x0F;
4420-
dst[j] = (float(q0) - shift) * d;
4421-
dst[j + QK4_0/2] = (float(q1) - shift) * d;
4422-
}
4423-
}
4424-
4425-
void dequant_q4_1_f(device const block_q4_1 * src_block, device float * dst) {
4426-
float d = float(src_block->d);
4427-
float vmin = float(src_block->m);
4428-
4429-
#pragma unroll(16)
4430-
for (int j = 0; j < QK4_1/2; j++) {
4431-
uint8_t q = src_block->qs[j];
4432-
uint8_t q0 = q & 0x0F;
4433-
uint8_t q1 = (q >> 4) & 0x0F;
4434-
dst[j] = vmin + d * float(q0);
4435-
dst[j + QK4_1/2] = vmin + d * float(q1);
4436-
}
4437-
}
4438-
4439-
void dequant_q5_0_f(device const block_q5_0 * src_block, device float * dst) {
4440-
float d = float(src_block->d);
4441-
const float shift = 16.f;
4442-
4443-
// Combine the four qh bytes into a 32-bit value.
4444-
uint32_t qhVal = 0
4445-
| ((uint32_t) src_block->qh[0] << 0)
4446-
| ((uint32_t) src_block->qh[1] << 8)
4447-
| ((uint32_t) src_block->qh[2] << 16)
4448-
| ((uint32_t) src_block->qh[3] << 24);
4449-
4450-
// First half
4451-
#pragma unroll(16)
4452-
for (int j = 0; j < QK5_0/2; j++) {
4453-
uint8_t q = src_block->qs[j];
4454-
uint8_t lowNib = q & 0x0F;
4455-
uint8_t highBit = (qhVal >> j) & 0x1;
4456-
uint8_t qVal = (highBit << 4) | lowNib;
4457-
dst[j] = (float(qVal) - shift) * d;
4458-
}
4459-
// Second half
4460-
#pragma unroll(16)
4461-
for (int j = QK5_0/2; j < QK5_0; j++) {
4462-
int k = j - QK5_0/2;
4463-
uint8_t q = src_block->qs[k];
4464-
uint8_t hiNib = (q >> 4) & 0x0F;
4465-
uint8_t highBit = (qhVal >> j) & 0x1;
4466-
uint8_t qVal = (highBit << 4) | hiNib;
4467-
dst[j] = (float(qVal) - shift) * d;
4468-
}
4469-
}
4470-
4471-
void dequant_q5_1_f(device const block_q5_1 * src_block, device float * dst) {
4472-
float d = float(src_block->d);
4473-
float vmin = float(src_block->m);
4474-
4475-
uint32_t qhVal = 0
4476-
| ((uint32_t) src_block->qh[0] << 0)
4477-
| ((uint32_t) src_block->qh[1] << 8)
4478-
| ((uint32_t) src_block->qh[2] << 16)
4479-
| ((uint32_t) src_block->qh[3] << 24);
4480-
4481-
// First half
4482-
#pragma unroll(16)
4483-
for (int j = 0; j < QK5_1/2; j++) {
4484-
uint8_t q = src_block->qs[j];
4485-
uint8_t lowNib = q & 0x0F;
4486-
uint8_t highBit = (qhVal >> j) & 0x1;
4487-
uint8_t qVal = (highBit << 4) | lowNib;
4488-
dst[j] = vmin + d * float(qVal);
4489-
}
4490-
// Second half
4491-
#pragma unroll(16)
4492-
for (int j = QK5_1/2; j < QK5_1; j++) {
4493-
int k = j - QK5_1/2;
4494-
uint8_t q = src_block->qs[k];
4495-
uint8_t hiNib = (q >> 4) & 0x0F;
4496-
uint8_t highBit = (qhVal >> j) & 0x1;
4497-
uint8_t qVal = (highBit << 4) | hiNib;
4498-
dst[j] = vmin + d * float(qVal);
4499-
}
4500-
}
4501-
4502-
void dequant_q8_0_f(device const block_q8_0 * src_block, device float * dst) {
4503-
const float d = (float)src_block->d;
4504-
4505-
#pragma unroll(32)
4506-
for (int j = 0; j < QK8_0; j++) {
4507-
dst[j] = src_block->qs[j] * d;
4508-
}
4509-
}
4510-
4511-
typedef decltype(kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>) cpy_q_t;
4512-
4513-
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_0, QK4_0, dequant_q4_0_f>;
4514-
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q4_1, QK4_1, dequant_q4_1_f>;
4515-
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_0, QK5_0, dequant_q5_0_f>;
4516-
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q5_1, QK5_1, dequant_q5_1_f>;
4517-
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_t kernel_cpy_q_f32<block_q8_0, QK8_0, dequant_q8_0_f>;
4518-
45194418
template<typename args_t>
45204419
void kernel_mul_mv_q2_K_f32_impl(
45214420
args_t args,

0 commit comments

Comments
 (0)