
Commit 9b06a0e

zhiyuan1i authored and MollySophia committed
initial support for apple
1 parent 686899d commit 9b06a0e

2 files changed: +140 −0 lines changed


ggml/src/ggml-metal/ggml-metal.m

Lines changed: 54 additions & 0 deletions
@@ -171,6 +171,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -777,6 +778,7 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1237,6 +1239,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -2140,6 +2143,57 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_OP_RWKV_WKV6:
+            {
+                const int64_t B = dst->src[5]->ne[1];
+                const int64_t T = dst->src[0]->ne[3];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[2];
+
+                GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64
+
+                size_t offs_k = 0;
+                size_t offs_v = 0;
+                size_t offs_r = 0;
+                size_t offs_tf = 0;
+                size_t offs_td = 0;
+                size_t offs_s = 0;
+                size_t offs_dst = 0;
+
+                id<MTLBuffer> id_k = dst->src[0] ? ggml_metal_get_buffer(dst->src[0], &offs_k) : nil;
+                id<MTLBuffer> id_v = dst->src[1] ? ggml_metal_get_buffer(dst->src[1], &offs_v) : nil;
+                id<MTLBuffer> id_r = dst->src[2] ? ggml_metal_get_buffer(dst->src[2], &offs_r) : nil;
+                id<MTLBuffer> id_tf = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_tf) : nil;
+                id<MTLBuffer> id_td = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_td) : nil;
+                id<MTLBuffer> id_s = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_s) : nil;
+                id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
+
+                id<MTLCommandBuffer> command_buffer = ctx->queue.commandBuffer;
+                id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_k offset:offs_k atIndex:0];
+                [encoder setBuffer:id_v offset:offs_v atIndex:1];
+                [encoder setBuffer:id_r offset:offs_r atIndex:2];
+                [encoder setBuffer:id_tf offset:offs_tf atIndex:3];
+                [encoder setBuffer:id_td offset:offs_td atIndex:4];
+                [encoder setBuffer:id_s offset:offs_s atIndex:5];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
+                [encoder setBytes:&B length:sizeof(B) atIndex:7];
+                [encoder setBytes:&T length:sizeof(T) atIndex:8];
+                [encoder setBytes:&C length:sizeof(C) atIndex:9];
+                [encoder setBytes:&H length:sizeof(H) atIndex:10];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C / H, 1, 1)];
+
+                [encoder endEncoding];
+                [command_buffer commit];
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 GGML_ASSERT(ne00 == ne10);
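For reference, the dispatch above launches one threadgroup per (sequence, head) pair with C / H = 64 threads each, which is what the asserts enforce and what the Metal kernel in the next file assumes. A minimal plain-C sketch of that geometry (the helper name and struct here are illustrative only, not part of this commit):

    #include <assert.h>
    #include <stdint.h>

    typedef struct {
        int64_t n_threadgroups;     // MTLSizeMake(B * H, 1, 1)
        int64_t threads_per_group;  // MTLSizeMake(C / H, 1, 1)
    } wkv6_dispatch_t;

    static wkv6_dispatch_t wkv6_dispatch(int64_t B, int64_t C, int64_t H) {
        assert(C % H == 0);
        assert(C / H == 64);  // the kernel is written for RWKV6 with HEAD_SIZE == 64
        wkv6_dispatch_t d = {
            .n_threadgroups    = B * H,   // one threadgroup per (sequence, head) pair
            .threads_per_group = C / H,   // 64 threads: one per channel within a head
        };
        return d;
    }

For example, with a hypothetical C = 1024 and H = 16, this yields B * 16 threadgroups of 1024 / 16 = 64 threads.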

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 86 additions & 0 deletions
@@ -1366,6 +1366,92 @@ kernel void kernel_ssm_scan_f32(
     }
 }
 
+kernel void kernel_rwkv_wkv6_f32(
+        device const float * k,
+        device const float * v,
+        device const float * r,
+        device const float * tf,
+        device const float * td,
+        device const float * state_in,
+        device float * dst,
+        constant uint & B,
+        constant uint & T,
+        constant uint & C,
+        constant uint & H,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const uint head_size = 64; // rwkv6
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _k[head_size];
+    threadgroup float _r[head_size];
+    threadgroup float _tf[head_size];
+    threadgroup float _td[head_size];
+
+    float state[head_size];
+#pragma unroll(64)
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                            + i * head_size + tid];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    _tf[tid] = tf[head_id * head_size + tid];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0;
+
+#pragma unroll(64)
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            float4 temp = tf_vec * kv + s_vec;
+            y += dot(r_vec, temp);
+
+            s_vec = s_vec * td_vec + kv;
+            state[j] = s_vec.x;
+            state[j+1] = s_vec.y;
+            state[j+2] = s_vec.z;
+            state[j+3] = s_vec.w;
+        }
+
+        dst[t] = y;
+    }
+#pragma unroll(64)
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + i * head_size + tid] = state[i];
+    }
+}
+
 kernel void kernel_argmax(
         device const void * x,
         device int32_t * dst,
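As a reading aid for the new kernel, here is a minimal scalar C sketch of the per-head recurrence it implements for one sequence and one head; the function name, argument layout, and standalone form are illustrative assumptions, not part of this commit. In the Metal kernel each thread owns one value channel j (tid), the per-token k/r/td vectors are staged in threadgroup memory, and the updated state is written back into dst starting at offset T * C after the token loop.

    // Illustrative scalar reference (not part of this commit) for one (sequence, head) pair.
    // k, v, r, td: [n_tokens][64]; tf: [64]; state: [64][64], updated in place; out: [n_tokens][64].
    enum { HEAD_SIZE = 64 };

    static void rwkv_wkv6_head_ref(int n_tokens,
                                   const float * k, const float * v, const float * r,
                                   const float * tf, const float * td,
                                   float state[HEAD_SIZE][HEAD_SIZE], float * out) {
        for (int t = 0; t < n_tokens; t++) {
            for (int j = 0; j < HEAD_SIZE; j++) {        // value channel: one Metal thread each
                float y = 0.0f;
                for (int i = 0; i < HEAD_SIZE; i++) {    // key channel: the unrolled float4 loop
                    const float kv = k[t*HEAD_SIZE + i] * v[t*HEAD_SIZE + j];
                    y += r[t*HEAD_SIZE + i] * (tf[i] * kv + state[i][j]);  // tf-weighted current term plus carried state
                    state[i][j] = state[i][j] * td[t*HEAD_SIZE + i] + kv;  // decay by td, accumulate kv
                }
                out[t*HEAD_SIZE + j] = y;
            }
        }
    }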

0 commit comments
