Skip to content

Commit 6b58ae9

Browse files
committed
metal : add F32 -> Q4_1 copy kernel
1 parent 9d69ecc commit 6b58ae9

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed

ggml-metal.m

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@
120120
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
121121
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
122122
GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
123-
//GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
123+
GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
124124
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
125125
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
126126
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
@@ -331,7 +331,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
331331
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
332332
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
333333
GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
334-
//GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
334+
GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
335335
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
336336
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
337337
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
@@ -437,7 +437,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
437437
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
438438
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
439439
GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
440-
//GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
440+
GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
441441
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
442442
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
443443
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
@@ -1578,7 +1578,7 @@ void ggml_metal_graph_compute(
15781578
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
15791579
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
15801580
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
1581-
//case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
1581+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
15821582
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
15831583
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
15841584
default: GGML_ASSERT(false && "not implemented");

ggml-metal.metal

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,72 @@ kernel void kernel_cpy_f32_q4_0(
15861586
}
15871587
}
15881588

1589+
kernel void kernel_cpy_f32_q4_1(
1590+
device const float * src0,
1591+
device void * dst,
1592+
constant int64_t & ne00,
1593+
constant int64_t & ne01,
1594+
constant int64_t & ne02,
1595+
constant int64_t & ne03,
1596+
constant uint64_t & nb00,
1597+
constant uint64_t & nb01,
1598+
constant uint64_t & nb02,
1599+
constant uint64_t & nb03,
1600+
constant int64_t & ne0,
1601+
constant int64_t & ne1,
1602+
constant int64_t & ne2,
1603+
constant int64_t & ne3,
1604+
constant uint64_t & nb0,
1605+
constant uint64_t & nb1,
1606+
constant uint64_t & nb2,
1607+
constant uint64_t & nb3,
1608+
uint3 tgpig[[threadgroup_position_in_grid]],
1609+
uint3 tpitg[[thread_position_in_threadgroup]],
1610+
uint3 ntg[[threads_per_threadgroup]]) {
1611+
const int64_t i03 = tgpig[2];
1612+
const int64_t i02 = tgpig[1];
1613+
const int64_t i01 = tgpig[0];
1614+
1615+
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
1616+
1617+
const int64_t i3 = n / (ne2*ne1*ne0);
1618+
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
1619+
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
1620+
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
1621+
1622+
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1623+
1624+
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
1625+
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
1626+
1627+
float min = FLT_MAX;
1628+
float max = -FLT_MAX;
1629+
1630+
for (int j = 0; j < QK4_1; j++) {
1631+
const float v = src[j];
1632+
if (min > v) min = v;
1633+
if (max < v) max = v;
1634+
}
1635+
1636+
const float d = (max - min) / ((1 << 4) - 1);
1637+
const float id = d ? 1.0f/d : 0.0f;
1638+
1639+
dst_data[i00/QK4_1].d = d;
1640+
dst_data[i00/QK4_1].m = min;
1641+
1642+
for (int j = 0; j < QK4_1/2; ++j) {
1643+
const float x0 = (src[0 + j] - min)*id;
1644+
const float x1 = (src[QK4_1/2 + j] - min)*id;
1645+
1646+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
1647+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
1648+
1649+
dst_data[i00/QK4_1].qs[j] = xi0;
1650+
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
1651+
}
1652+
}
1653+
}
1654+
15891655
kernel void kernel_concat(
15901656
device const char * src0,
15911657
device const char * src1,

0 commit comments

Comments
 (0)