Skip to content

Commit 9d69ecc

Browse files
committed
metal : add F32 -> Q4_0 copy kernel
1 parent 7864a2c commit 9d69ecc

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

ggml-metal.m

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@
119119
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
120120
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
121121
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
122+
GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
123+
//GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
124+
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
125+
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
122126
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
123127
GGML_METAL_DECL_KERNEL(concat);
124128
GGML_METAL_DECL_KERNEL(sqr);
@@ -326,6 +330,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
326330
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
327331
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
328332
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
333+
GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
334+
//GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
335+
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
336+
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
329337
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
330338
GGML_METAL_ADD_KERNEL(concat);
331339
GGML_METAL_ADD_KERNEL(sqr);
@@ -428,6 +436,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
428436
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
429437
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
430438
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
439+
GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
440+
//GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
441+
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
442+
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
431443
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
432444
GGML_METAL_DEL_KERNEL(concat);
433445
GGML_METAL_DEL_KERNEL(sqr);
@@ -1565,6 +1577,10 @@ void ggml_metal_graph_compute(
15651577
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
15661578
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
15671579
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
1580+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
1581+
//case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
1582+
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
1583+
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
15681584
default: GGML_ASSERT(false && "not implemented");
15691585
};
15701586
} break;

ggml-metal.metal

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using namespace metal;
44

55
#define MAX(x, y) ((x) > (y) ? (x) : (y))
6+
#define MIN(x, y) ((x) < (y) ? (x) : (y))
67

78
#define QK4_0 32
89
#define QR4_0 2
@@ -1518,6 +1519,73 @@ kernel void kernel_cpy_f32_q8_0(
15181519
}
15191520
}
15201521

1522+
// Copy F32 data into Q4_0 blocks.
//
// The source row is selected by the threadgroup position (tgpig[0..2] ->
// i01, i02, i03). Threads inside the threadgroup walk the row in steps of
// ntg.x blocks, each thread quantizing QK4_0 consecutive floats per step:
//
//   d     = (element with the largest magnitude) / -8
//   q[k]  = min(15, (int8_t)(src[k]/d + 8.5))      (all zeros -> id = 0)
//   qs[j] = q[j] | (q[j + QK4_0/2] << 4)           for j in [0, QK4_0/2)
//
// NOTE(review): every block reads QK4_0 floats starting at i00 for any
// i00 < ne00, so this presumably relies on ne00 being a multiple of QK4_0 -
// confirm against the host-side dispatch.
// NOTE(review): src[k] is indexed as a plain float array after applying the
// byte offset i00*nb00, which looks like it assumes the floats of a block are
// contiguous in memory - TODO confirm.
kernel void kernel_cpy_f32_q4_0(
        device const float * src0,
        device       void  * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // coordinates of the source row handled by this threadgroup
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    // flat element index of the first element of that row
    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // split the flat index into destination coordinates by peeling off one
    // dimension at a time; i0 is expressed in blocks rather than elements
    const int64_t dst_plane  = ne1*ne0;
    const int64_t dst_volume = ne2*dst_plane;

    int64_t rem = n;

    const int64_t i3 = rem / dst_volume;
    rem -= i3*dst_volume;

    const int64_t i2 = rem / dst_plane;
    rem -= i2*dst_plane;

    const int64_t i1 = rem / ne0;
    rem -= i1*ne0;

    const int64_t i0 = rem / QK4_0;

    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
        device const float * src = (device const float *)((device const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

        // output block that receives the QK4_0 values starting at i00
        device block_q4_0 * blk = dst_data + i00/QK4_0;

        // find the element with the largest magnitude, keeping its sign
        float best_abs    = 0.0f;
        float best_signed = 0.0f;

        for (int k = 0; k < QK4_0; ++k) {
            const float v = src[k];
            const float a = fabs(v);
            if (a > best_abs) {
                best_abs    = a;
                best_signed = v;
            }
        }

        // block scale and its inverse (inverse is 0 when the block is all zeros)
        const float d  = best_signed / -8;
        const float id = (d != 0.0f) ? 1.0f/d : 0.0f;

        blk->d = d;

        // element j goes to the low nibble, element j + QK4_0/2 to the high nibble
        for (int j = 0; j < QK4_0/2; ++j) {
            const float lo = src[j]*id;
            const float hi = src[j + QK4_0/2]*id;

            const uint8_t q_lo = MIN(15, (int8_t)(lo + 8.5f));
            const uint8_t q_hi = MIN(15, (int8_t)(hi + 8.5f));

            blk->qs[j] = q_lo | (q_hi << 4);
        }
    }
}
1588+
15211589
kernel void kernel_concat(
15221590
device const char * src0,
15231591
device const char * src1,

0 commit comments

Comments
 (0)