@@ -171,6 +171,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
171
171
GGML_METAL_KERNEL_TYPE_NORM,
172
172
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
173
173
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
174
+ GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
174
175
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
175
176
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
176
177
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -777,6 +778,7 @@ @implementation GGMLMetalClass
777
778
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
778
779
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
779
780
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
781
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true );
780
782
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
781
783
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
782
784
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1237,6 +1239,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1237
1239
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
1238
1240
case GGML_OP_SSM_CONV:
1239
1241
case GGML_OP_SSM_SCAN:
1242
+ case GGML_OP_RWKV_WKV6:
1240
1243
return true ;
1241
1244
case GGML_OP_MUL_MAT:
1242
1245
case GGML_OP_MUL_MAT_ID:
@@ -2140,6 +2143,57 @@ static void ggml_metal_encode_node(
2140
2143
2141
2144
[encoder dispatchThreadgroups: MTLSizeMake (d_inner, n_seqs, 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2142
2145
} break ;
2146
+ case GGML_OP_RWKV_WKV6:
2147
+ {
2148
+ const int64_t B = dst->src [5 ]->ne [1 ];
2149
+ const int64_t T = dst->src [0 ]->ne [3 ];
2150
+ const int64_t C = dst->ne [0 ];
2151
+ const int64_t H = dst->src [0 ]->ne [2 ];
2152
+
2153
+ GGML_ASSERT (dst->src [5 ]->type == GGML_TYPE_F32);
2154
+ GGML_ASSERT (C % H == 0 );
2155
+ GGML_ASSERT (C / H == 64 ); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64
2156
+
2157
+ size_t offs_k = 0 ;
2158
+ size_t offs_v = 0 ;
2159
+ size_t offs_r = 0 ;
2160
+ size_t offs_tf = 0 ;
2161
+ size_t offs_td = 0 ;
2162
+ size_t offs_s = 0 ;
2163
+ size_t offs_dst = 0 ;
2164
+
2165
+ id <MTLBuffer > id_k = dst->src [0 ] ? ggml_metal_get_buffer (dst->src [0 ], &offs_k) : nil ;
2166
+ id <MTLBuffer > id_v = dst->src [1 ] ? ggml_metal_get_buffer (dst->src [1 ], &offs_v) : nil ;
2167
+ id <MTLBuffer > id_r = dst->src [2 ] ? ggml_metal_get_buffer (dst->src [2 ], &offs_r) : nil ;
2168
+ id <MTLBuffer > id_tf = dst->src [3 ] ? ggml_metal_get_buffer (dst->src [3 ], &offs_tf) : nil ;
2169
+ id <MTLBuffer > id_td = dst->src [4 ] ? ggml_metal_get_buffer (dst->src [4 ], &offs_td) : nil ;
2170
+ id <MTLBuffer > id_s = dst->src [5 ] ? ggml_metal_get_buffer (dst->src [5 ], &offs_s) : nil ;
2171
+ id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
2172
+
2173
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline ;
2174
+
2175
+ id <MTLCommandBuffer > command_buffer = ctx->queue .commandBuffer ;
2176
+ id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoder ];
2177
+
2178
+ [encoder setComputePipelineState: pipeline];
2179
+ [encoder setBuffer: id_k offset: offs_k atIndex: 0 ];
2180
+ [encoder setBuffer: id_v offset: offs_v atIndex: 1 ];
2181
+ [encoder setBuffer: id_r offset: offs_r atIndex: 2 ];
2182
+ [encoder setBuffer: id_tf offset: offs_tf atIndex: 3 ];
2183
+ [encoder setBuffer: id_td offset: offs_td atIndex: 4 ];
2184
+ [encoder setBuffer: id_s offset: offs_s atIndex: 5 ];
2185
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 6 ];
2186
+
2187
+ [encoder setBytes: &B length: sizeof (B) atIndex: 7 ];
2188
+ [encoder setBytes: &T length: sizeof (T) atIndex: 8 ];
2189
+ [encoder setBytes: &C length: sizeof (C) atIndex: 9 ];
2190
+ [encoder setBytes: &H length: sizeof (H) atIndex: 10 ];
2191
+
2192
+ [encoder dispatchThreadgroups: MTLSizeMake (B * H, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (C/ H, 1 , 1 )];
2193
+
2194
+ [encoder endEncoding ];
2195
+ [command_buffer commit ];
2196
+ } break ;
2143
2197
case GGML_OP_MUL_MAT:
2144
2198
{
2145
2199
GGML_ASSERT (ne00 == ne10);
0 commit comments