@@ -178,6 +178,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
178
178
GGML_METAL_KERNEL_TYPE_NORM,
179
179
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
180
180
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
181
+ GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
181
182
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
182
183
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
183
184
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -784,6 +785,7 @@ @implementation GGMLMetalClass
784
785
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NORM, norm, true );
785
786
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
786
787
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
788
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true );
787
789
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
788
790
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
789
791
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1245,6 +1247,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1245
1247
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
1246
1248
case GGML_OP_SSM_CONV:
1247
1249
case GGML_OP_SSM_SCAN:
1250
+ case GGML_OP_RWKV_WKV6:
1248
1251
return true ;
1249
1252
case GGML_OP_MUL_MAT:
1250
1253
case GGML_OP_MUL_MAT_ID:
@@ -2148,6 +2151,57 @@ static void ggml_metal_encode_node(
2148
2151
2149
2152
[encoder dispatchThreadgroups: MTLSizeMake (d_inner, n_seqs, 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2150
2153
} break ;
2154
case GGML_OP_RWKV_WKV6:
    {
        // Encode the RWKV6 WKV kernel for this node.
        //
        // Sizes are read straight from the operand shapes:
        //   B = src[5]->ne[1], T = src[0]->ne[3], C = dst->ne[0], H = src[0]->ne[2]
        // NOTE(review): presumably B = sequences, T = tokens, C = channels,
        // H = heads -- confirm against the op's graph-side constructor.
        const int64_t B = dst->src[5]->ne[1];
        const int64_t T = dst->src[0]->ne[3];
        const int64_t C = dst->ne[0];
        const int64_t H = dst->src[0]->ne[2];

        GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
        GGML_ASSERT(C % H == 0);
        GGML_ASSERT(C / H == 64); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64

        // Byte offsets of each operand inside its backing Metal buffer,
        // filled in by ggml_metal_get_buffer().
        size_t offs_k   = 0;
        size_t offs_v   = 0;
        size_t offs_r   = 0;
        size_t offs_tf  = 0;
        size_t offs_td  = 0;
        size_t offs_s   = 0;
        size_t offs_dst = 0;

        id<MTLBuffer> id_k   = dst->src[0] ? ggml_metal_get_buffer(dst->src[0], &offs_k)  : nil;
        id<MTLBuffer> id_v   = dst->src[1] ? ggml_metal_get_buffer(dst->src[1], &offs_v)  : nil;
        id<MTLBuffer> id_r   = dst->src[2] ? ggml_metal_get_buffer(dst->src[2], &offs_r)  : nil;
        id<MTLBuffer> id_tf  = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_tf) : nil;
        id<MTLBuffer> id_td  = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_td) : nil;
        id<MTLBuffer> id_s   = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_s)  : nil;
        id<MTLBuffer> id_dst = dst         ? ggml_metal_get_buffer(dst,         &offs_dst) : nil;

        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;

        // Record onto the encoder the enclosing function is already using
        // (the other cases of this switch dispatch on it as well). Creating
        // and committing a private command buffer here would submit this
        // kernel on its own, independently of the work still being recorded
        // on the shared encoder, so nothing would order it against the nodes
        // around it.
        [encoder setComputePipelineState:pipeline];

        // Buffer slots 0..6: k, v, r, tf, td, s, dst.
        [encoder setBuffer:id_k   offset:offs_k   atIndex:0];
        [encoder setBuffer:id_v   offset:offs_v   atIndex:1];
        [encoder setBuffer:id_r   offset:offs_r   atIndex:2];
        [encoder setBuffer:id_tf  offset:offs_tf  atIndex:3];
        [encoder setBuffer:id_td  offset:offs_td  atIndex:4];
        [encoder setBuffer:id_s   offset:offs_s   atIndex:5];
        [encoder setBuffer:id_dst offset:offs_dst atIndex:6];

        // Scalar arguments in slots 7..10, passed as 64-bit integers.
        [encoder setBytes:&B length:sizeof(B) atIndex:7];
        [encoder setBytes:&T length:sizeof(T) atIndex:8];
        [encoder setBytes:&C length:sizeof(C) atIndex:9];
        [encoder setBytes:&H length:sizeof(H) atIndex:10];

        // B*H threadgroups of C/H threads each (C/H == 64, asserted above).
        [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C / H, 1, 1)];
    } break;
2151
2205
case GGML_OP_MUL_MAT:
2152
2206
{
2153
2207
GGML_ASSERT (ne00 == ne10);
0 commit comments