Skip to content

Commit d672ecc

Browse files
committed
Fix metal wkv6 inference
Signed-off-by: Molly Sophia <[email protected]>
1 parent 757ebb2 commit d672ecc

File tree

1 file changed: +17 −31 lines changed

1 file changed: +17 −31 lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 17 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -2154,42 +2154,31 @@ static void ggml_metal_encode_node(
21542154
case GGML_OP_RWKV_WKV6:
21552155
{
21562156
const int64_t B = dst->src[5]->ne[1];
2157-
const int64_t T = dst->src[0]->ne[3];
2157+
const int64_t T = dst->src[0]->ne[2];
21582158
const int64_t C = dst->ne[0];
2159-
const int64_t H = dst->src[0]->ne[2];
2159+
const int64_t H = dst->src[0]->ne[1];
21602160

21612161
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
21622162
GGML_ASSERT(C % H == 0);
2163-
GGML_ASSERT(C / H == 64); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64
2164-
2165-
size_t offs_k = 0;
2166-
size_t offs_v = 0;
2167-
size_t offs_r = 0;
2168-
size_t offs_tf = 0;
2169-
size_t offs_td = 0;
2170-
size_t offs_s = 0;
2171-
size_t offs_dst = 0;
2172-
2173-
id<MTLBuffer> id_k = dst->src[0] ? ggml_metal_get_buffer(dst->src[0], &offs_k) : nil;
2174-
id<MTLBuffer> id_v = dst->src[1] ? ggml_metal_get_buffer(dst->src[1], &offs_v) : nil;
2175-
id<MTLBuffer> id_r = dst->src[2] ? ggml_metal_get_buffer(dst->src[2], &offs_r) : nil;
2176-
id<MTLBuffer> id_tf = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_tf) : nil;
2177-
id<MTLBuffer> id_td = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_td) : nil;
2178-
id<MTLBuffer> id_s = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_s) : nil;
2179-
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
2163+
GGML_ASSERT(C / H == 64);
21802164

2181-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
2165+
size_t offs_src3 = 0;
2166+
size_t offs_src4 = 0;
2167+
size_t offs_src5 = 0;
21822168

2183-
id<MTLCommandBuffer> command_buffer = ctx->queue.commandBuffer;
2184-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
2169+
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
2170+
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
2171+
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
2172+
2173+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
21852174

21862175
[encoder setComputePipelineState:pipeline];
2187-
[encoder setBuffer:id_k offset:offs_k atIndex:0];
2188-
[encoder setBuffer:id_v offset:offs_v atIndex:1];
2189-
[encoder setBuffer:id_r offset:offs_r atIndex:2];
2190-
[encoder setBuffer:id_tf offset:offs_tf atIndex:3];
2191-
[encoder setBuffer:id_td offset:offs_td atIndex:4];
2192-
[encoder setBuffer:id_s offset:offs_s atIndex:5];
2176+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2177+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2178+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2179+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2180+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2181+
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
21932182
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
21942183

21952184
[encoder setBytes:&B length:sizeof(B) atIndex:7];
@@ -2198,9 +2187,6 @@ static void ggml_metal_encode_node(
21982187
[encoder setBytes:&H length:sizeof(H) atIndex:10];
21992188

22002189
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
2201-
2202-
[encoder endEncoding];
2203-
[command_buffer commit];
22042190
} break;
22052191
case GGML_OP_MUL_MAT:
22062192
{

0 commit comments

Comments
 (0)