
Commit d11f487

WKV7 Metal
Signed-off-by: Molly Sophia <[email protected]>
1 parent 6c15983 commit d11f487

2 files changed: +141, -4 lines

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 43 additions & 0 deletions
@@ -174,6 +174,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
     GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -783,6 +784,7 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1246,6 +1248,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -2208,6 +2211,46 @@ static void ggml_metal_encode_node(
                 [encoder setBytes:&C length:sizeof(C) atIndex:9];
                 [encoder setBytes:&H length:sizeof(H) atIndex:10];
 
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
+            } break;
+        case GGML_OP_RWKV_WKV7:
+            {
+                const int64_t B = dst->src[6]->ne[1];
+                const int64_t T = dst->src[0]->ne[2];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[1];
+
+                GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+                size_t offs_src6 = 0;
+
+                id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
+                id<MTLBuffer> id_src6 = dst->src[6] ? ggml_metal_get_buffer(dst->src[6], &offs_src6) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:7];
+
+                [encoder setBytes:&B length:sizeof(B) atIndex:8];
+                [encoder setBytes:&T length:sizeof(T) atIndex:9];
+                [encoder setBytes:&C length:sizeof(C) atIndex:10];
+                [encoder setBytes:&H length:sizeof(H) atIndex:11];
+
                 [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
             } break;
         case GGML_OP_MUL_MAT:
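Note on the launch geometry (not part of the diff): the new GGML_OP_RWKV_WKV7 case reads B from the state tensor (src[6]), T and H from src[0], and C from dst, asserts a head size of 64, and then dispatches one threadgroup per (sequence, head) pair with C/H threads each. A minimal C sketch of that arithmetic follows; the shape values are illustrative and not taken from the commit.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Illustrative only: mirrors how the encoder above derives its launch
 * geometry for GGML_OP_RWKV_WKV7 from the tensor shapes. */
int main(void) {
    const int64_t H = 32;     /* heads, taken from src[0]->ne[1] (example value)   */
    const int64_t C = H * 64; /* channels = heads * head_size, from dst->ne[0]     */
    const int64_t B = 2;      /* sequences in the batch, from src[6]->ne[1]        */

    assert(C % H == 0 && C / H == 64); /* same asserts as the encoder */

    printf("threadgroups          : %lld\n", (long long)(B * H)); /* one per (seq, head) */
    printf("threads per threadgroup: %lld\n", (long long)(C / H)); /* one per channel     */
    return 0;
}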

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 98 additions & 4 deletions
@@ -1453,10 +1453,10 @@ kernel void kernel_rwkv_wkv6_f32(
             y += dot(r_vec, temp);
 
             s_vec = s_vec * td_vec + kv;
-            state[j] = s_vec.x;
-            state[j+1] = s_vec.y;
-            state[j+2] = s_vec.z;
-            state[j+3] = s_vec.w;
+            state[j] = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
         }
 
         dst[t] = y;
@@ -1468,6 +1468,100 @@ kernel void kernel_rwkv_wkv6_f32(
     }
 }
 
+kernel void kernel_rwkv_wkv7_f32(
+        device const float * r,
+        device const float * w,
+        device const float * k,
+        device const float * v,
+        device const float * a,
+        device const float * b,
+        device const float * state_in,
+        device       float * dst,
+        constant uint & B,
+        constant uint & T,
+        constant uint & C,
+        constant uint & H,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const uint head_size = 64;
+    const uint batch_id = tgpig.x / H;
+    const uint head_id = tgpig.x % H;
+    const uint tid = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _r[head_size];
+    threadgroup float _w[head_size];
+    threadgroup float _k[head_size];
+    threadgroup float _a[head_size];
+    threadgroup float _b[head_size];
+
+    float state[head_size];
+#pragma unroll(64)
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+                            + tid * head_size + i];
+    }
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _r[tid] = r[t];
+        _w[tid] = w[t];
+        _k[tid] = k[t];
+        _a[tid] = a[t];
+        _b[tid] = b[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0, sa = 0.0;
+
+        float4 sa_vec(0.0);
+#pragma unroll(64)
+        for (int j = 0; j < head_size; j += 4) {
+            float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+            sa_vec += a_vec * s_vec;
+        }
+        sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
+
+#pragma unroll(64)
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
+            float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
+            float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            s_vec = s_vec * w_vec + kv + sa * b_vec;
+            y += dot(s_vec, r_vec);
+
+            state[j] = s_vec[0];
+            state[j+1] = s_vec[1];
+            state[j+2] = s_vec[2];
+            state[j+3] = s_vec[3];
+        }
+
+        dst[t] = y;
+    }
+#pragma unroll(64)
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + tid * head_size + i] = state[i];
+    }
+}
+
 kernel void kernel_argmax(
     device const void * x,
     device int32_t * dst,
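For reference (not part of the commit), the per-head recurrence that kernel_rwkv_wkv7_f32 computes can be written as a small single-threaded C loop. The names and the [n_tokens][64] layout below are illustrative assumptions chosen to mirror the kernel's indexing, with head_size fixed at 64 as the kernel assumes; each output channel i plays the role of one thread (tid): sa = a·state_i, state_i = state_i*w + k*v[i] + sa*b, y[i] = r·state_i.

#include <stddef.h>

/* Illustrative single-head reference of the WKV7 recurrence implemented by
 * kernel_rwkv_wkv7_f32 above (head_size fixed at 64). Not part of ggml.
 * r, w, k, v, a, b: [n_tokens][64]; state: [64][64], rows indexed like the
 * kernel's per-thread "tid" dimension; y: [n_tokens][64]. */
void rwkv_wkv7_ref(size_t n_tokens,
                   const float *r, const float *w, const float *k,
                   const float *v, const float *a, const float *b,
                   float state[64][64], float *y) {
    for (size_t t = 0; t < n_tokens; ++t) {
        const float *rt = r + t*64, *wt = w + t*64, *kt = k + t*64;
        const float *vt = v + t*64, *at = a + t*64, *bt = b + t*64;
        for (int i = 0; i < 64; ++i) {       /* i plays the role of tid */
            /* sa = dot(a_t, state row i) */
            float sa = 0.0f;
            for (int j = 0; j < 64; ++j) {
                sa += at[j] * state[i][j];
            }
            /* state update and output: s = s*w + k*v[i] + sa*b; y[i] = dot(r, s) */
            float yi = 0.0f;
            for (int j = 0; j < 64; ++j) {
                state[i][j] = state[i][j] * wt[j] + kt[j] * vt[i] + sa * bt[j];
                yi += rt[j] * state[i][j];
            }
            y[t*64 + i] = yi;
        }
    }
}

Running the same inputs through this loop and through the Metal path should agree up to floating-point rounding, which makes it a convenient cross-check for the kernel's output and final state.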
