Skip to content

Commit 5f4dc3e

Browse files
committed
Fix metal wkv6 inference
Signed-off-by: Molly Sophia <[email protected]>
1 parent 694b5d1 commit 5f4dc3e

File tree

1 file changed

+17
-31
lines changed

1 file changed

+17
-31
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2146,42 +2146,31 @@ static void ggml_metal_encode_node(
21462146
case GGML_OP_RWKV_WKV6:
21472147
{
21482148
const int64_t B = dst->src[5]->ne[1];
2149-
const int64_t T = dst->src[0]->ne[3];
2149+
const int64_t T = dst->src[0]->ne[2];
21502150
const int64_t C = dst->ne[0];
2151-
const int64_t H = dst->src[0]->ne[2];
2151+
const int64_t H = dst->src[0]->ne[1];
21522152

21532153
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
21542154
GGML_ASSERT(C % H == 0);
2155-
GGML_ASSERT(C / H == 64); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64
2156-
2157-
size_t offs_k = 0;
2158-
size_t offs_v = 0;
2159-
size_t offs_r = 0;
2160-
size_t offs_tf = 0;
2161-
size_t offs_td = 0;
2162-
size_t offs_s = 0;
2163-
size_t offs_dst = 0;
2164-
2165-
id<MTLBuffer> id_k = dst->src[0] ? ggml_metal_get_buffer(dst->src[0], &offs_k) : nil;
2166-
id<MTLBuffer> id_v = dst->src[1] ? ggml_metal_get_buffer(dst->src[1], &offs_v) : nil;
2167-
id<MTLBuffer> id_r = dst->src[2] ? ggml_metal_get_buffer(dst->src[2], &offs_r) : nil;
2168-
id<MTLBuffer> id_tf = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_tf) : nil;
2169-
id<MTLBuffer> id_td = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_td) : nil;
2170-
id<MTLBuffer> id_s = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_s) : nil;
2171-
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
2155+
GGML_ASSERT(C / H == 64);
21722156

2173-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
2157+
size_t offs_src3 = 0;
2158+
size_t offs_src4 = 0;
2159+
size_t offs_src5 = 0;
21742160

2175-
id<MTLCommandBuffer> command_buffer = ctx->queue.commandBuffer;
2176-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
2161+
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
2162+
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
2163+
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
2164+
2165+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
21772166

21782167
[encoder setComputePipelineState:pipeline];
2179-
[encoder setBuffer:id_k offset:offs_k atIndex:0];
2180-
[encoder setBuffer:id_v offset:offs_v atIndex:1];
2181-
[encoder setBuffer:id_r offset:offs_r atIndex:2];
2182-
[encoder setBuffer:id_tf offset:offs_tf atIndex:3];
2183-
[encoder setBuffer:id_td offset:offs_td atIndex:4];
2184-
[encoder setBuffer:id_s offset:offs_s atIndex:5];
2168+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2169+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2170+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2171+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2172+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
2173+
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
21852174
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
21862175

21872176
[encoder setBytes:&B length:sizeof(B) atIndex:7];
@@ -2190,9 +2179,6 @@ static void ggml_metal_encode_node(
21902179
[encoder setBytes:&H length:sizeof(H) atIndex:10];
21912180

21922181
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
2193-
2194-
[encoder endEncoding];
2195-
[command_buffer commit];
21962182
} break;
21972183
case GGML_OP_MUL_MAT:
21982184
{

0 commit comments

Comments
 (0)