@@ -2146,42 +2146,31 @@ static void ggml_metal_encode_node(
2146
2146
case GGML_OP_RWKV_WKV6:
2147
2147
{
2148
2148
const int64_t B = dst->src [5 ]->ne [1 ];
2149
- const int64_t T = dst->src [0 ]->ne [3 ];
2149
+ const int64_t T = dst->src [0 ]->ne [2 ];
2150
2150
const int64_t C = dst->ne [0 ];
2151
- const int64_t H = dst->src [0 ]->ne [2 ];
2151
+ const int64_t H = dst->src [0 ]->ne [1 ];
2152
2152
2153
2153
GGML_ASSERT (dst->src [5 ]->type == GGML_TYPE_F32);
2154
2154
GGML_ASSERT (C % H == 0 );
2155
- GGML_ASSERT (C / H == 64 ); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64
2156
-
2157
- size_t offs_k = 0 ;
2158
- size_t offs_v = 0 ;
2159
- size_t offs_r = 0 ;
2160
- size_t offs_tf = 0 ;
2161
- size_t offs_td = 0 ;
2162
- size_t offs_s = 0 ;
2163
- size_t offs_dst = 0 ;
2164
-
2165
- id <MTLBuffer > id_k = dst->src [0 ] ? ggml_metal_get_buffer (dst->src [0 ], &offs_k) : nil ;
2166
- id <MTLBuffer > id_v = dst->src [1 ] ? ggml_metal_get_buffer (dst->src [1 ], &offs_v) : nil ;
2167
- id <MTLBuffer > id_r = dst->src [2 ] ? ggml_metal_get_buffer (dst->src [2 ], &offs_r) : nil ;
2168
- id <MTLBuffer > id_tf = dst->src [3 ] ? ggml_metal_get_buffer (dst->src [3 ], &offs_tf) : nil ;
2169
- id <MTLBuffer > id_td = dst->src [4 ] ? ggml_metal_get_buffer (dst->src [4 ], &offs_td) : nil ;
2170
- id <MTLBuffer > id_s = dst->src [5 ] ? ggml_metal_get_buffer (dst->src [5 ], &offs_s) : nil ;
2171
- id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
2155
+ GGML_ASSERT (C / H == 64 );
2172
2156
2173
- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline ;
2157
+ size_t offs_src3 = 0 ;
2158
+ size_t offs_src4 = 0 ;
2159
+ size_t offs_src5 = 0 ;
2174
2160
2175
- id <MTLCommandBuffer > command_buffer = ctx->queue .commandBuffer ;
2176
- id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoder ];
2161
+ id <MTLBuffer > id_src3 = dst->src [3 ] ? ggml_metal_get_buffer (dst->src [3 ], &offs_src3) : nil ;
2162
+ id <MTLBuffer > id_src4 = dst->src [4 ] ? ggml_metal_get_buffer (dst->src [4 ], &offs_src4) : nil ;
2163
+ id <MTLBuffer > id_src5 = dst->src [5 ] ? ggml_metal_get_buffer (dst->src [5 ], &offs_src5) : nil ;
2164
+
2165
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline ;
2177
2166
2178
2167
[encoder setComputePipelineState: pipeline];
2179
- [encoder setBuffer: id_k offset: offs_k atIndex: 0 ];
2180
- [encoder setBuffer: id_v offset: offs_v atIndex: 1 ];
2181
- [encoder setBuffer: id_r offset: offs_r atIndex: 2 ];
2182
- [encoder setBuffer: id_tf offset: offs_tf atIndex: 3 ];
2183
- [encoder setBuffer: id_td offset: offs_td atIndex: 4 ];
2184
- [encoder setBuffer: id_s offset: offs_s atIndex: 5 ];
2168
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2169
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
2170
+ [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
2171
+ [encoder setBuffer: id_src3 offset: offs_src3 atIndex: 3 ];
2172
+ [encoder setBuffer: id_src4 offset: offs_src4 atIndex: 4 ];
2173
+ [encoder setBuffer: id_src5 offset: offs_src5 atIndex: 5 ];
2185
2174
[encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
2186
2175
2187
2176
[encoder setBytes: &B length: sizeof (B) atIndex: 7 ];
@@ -2190,9 +2179,6 @@ static void ggml_metal_encode_node(
2190
2179
[encoder setBytes: &H length: sizeof (H) atIndex: 10 ];
2191
2180
2192
2181
[encoder dispatchThreadgroups: MTLSizeMake (B * H, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (C/ H, 1 , 1 )];
2193
-
2194
- [encoder endEncoding ];
2195
- [command_buffer commit ];
2196
2182
} break ;
2197
2183
case GGML_OP_MUL_MAT:
2198
2184
{
0 commit comments