@@ -2154,42 +2154,31 @@ static void ggml_metal_encode_node(
2154
2154
case GGML_OP_RWKV_WKV6:
2155
2155
{
2156
2156
const int64_t B = dst->src [5 ]->ne [1 ];
2157
- const int64_t T = dst->src [0 ]->ne [3 ];
2157
+ const int64_t T = dst->src [0 ]->ne [2 ];
2158
2158
const int64_t C = dst->ne [0 ];
2159
- const int64_t H = dst->src [0 ]->ne [2 ];
2159
+ const int64_t H = dst->src [0 ]->ne [1 ];
2160
2160
2161
2161
GGML_ASSERT (dst->src [5 ]->type == GGML_TYPE_F32);
2162
2162
GGML_ASSERT (C % H == 0 );
2163
- GGML_ASSERT (C / H == 64 ); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64
2164
-
2165
- size_t offs_k = 0 ;
2166
- size_t offs_v = 0 ;
2167
- size_t offs_r = 0 ;
2168
- size_t offs_tf = 0 ;
2169
- size_t offs_td = 0 ;
2170
- size_t offs_s = 0 ;
2171
- size_t offs_dst = 0 ;
2172
-
2173
- id <MTLBuffer > id_k = dst->src [0 ] ? ggml_metal_get_buffer (dst->src [0 ], &offs_k) : nil ;
2174
- id <MTLBuffer > id_v = dst->src [1 ] ? ggml_metal_get_buffer (dst->src [1 ], &offs_v) : nil ;
2175
- id <MTLBuffer > id_r = dst->src [2 ] ? ggml_metal_get_buffer (dst->src [2 ], &offs_r) : nil ;
2176
- id <MTLBuffer > id_tf = dst->src [3 ] ? ggml_metal_get_buffer (dst->src [3 ], &offs_tf) : nil ;
2177
- id <MTLBuffer > id_td = dst->src [4 ] ? ggml_metal_get_buffer (dst->src [4 ], &offs_td) : nil ;
2178
- id <MTLBuffer > id_s = dst->src [5 ] ? ggml_metal_get_buffer (dst->src [5 ], &offs_s) : nil ;
2179
- id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
2163
+ GGML_ASSERT (C / H == 64 );
2180
2164
2181
- id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline ;
2165
+ size_t offs_src3 = 0 ;
2166
+ size_t offs_src4 = 0 ;
2167
+ size_t offs_src5 = 0 ;
2182
2168
2183
- id <MTLCommandBuffer > command_buffer = ctx->queue .commandBuffer ;
2184
- id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoder ];
2169
+ id <MTLBuffer > id_src3 = dst->src [3 ] ? ggml_metal_get_buffer (dst->src [3 ], &offs_src3) : nil ;
2170
+ id <MTLBuffer > id_src4 = dst->src [4 ] ? ggml_metal_get_buffer (dst->src [4 ], &offs_src4) : nil ;
2171
+ id <MTLBuffer > id_src5 = dst->src [5 ] ? ggml_metal_get_buffer (dst->src [5 ], &offs_src5) : nil ;
2172
+
2173
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline ;
2185
2174
2186
2175
[encoder setComputePipelineState: pipeline];
2187
- [encoder setBuffer: id_k offset: offs_k atIndex: 0 ];
2188
- [encoder setBuffer: id_v offset: offs_v atIndex: 1 ];
2189
- [encoder setBuffer: id_r offset: offs_r atIndex: 2 ];
2190
- [encoder setBuffer: id_tf offset: offs_tf atIndex: 3 ];
2191
- [encoder setBuffer: id_td offset: offs_td atIndex: 4 ];
2192
- [encoder setBuffer: id_s offset: offs_s atIndex: 5 ];
2176
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2177
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
2178
+ [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
2179
+ [encoder setBuffer: id_src3 offset: offs_src3 atIndex: 3 ];
2180
+ [encoder setBuffer: id_src4 offset: offs_src4 atIndex: 4 ];
2181
+ [encoder setBuffer: id_src5 offset: offs_src5 atIndex: 5 ];
2193
2182
[encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
2194
2183
2195
2184
[encoder setBytes: &B length: sizeof (B) atIndex: 7 ];
@@ -2198,9 +2187,6 @@ static void ggml_metal_encode_node(
2198
2187
[encoder setBytes: &H length: sizeof (H) atIndex: 10 ];
2199
2188
2200
2189
[encoder dispatchThreadgroups: MTLSizeMake (B * H, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (C/ H, 1 , 1 )];
2201
-
2202
- [encoder endEncoding ];
2203
- [command_buffer commit ];
2204
2190
} break ;
2205
2191
case GGML_OP_MUL_MAT:
2206
2192
{
0 commit comments