@@ -181,6 +181,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
181
181
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
182
182
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
183
183
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
184
+ GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
184
185
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
185
186
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
186
187
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -790,6 +791,7 @@ @implementation GGMLMetalClass
790
791
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
791
792
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
792
793
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true );
794
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true );
793
795
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
794
796
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
795
797
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1254,6 +1256,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1254
1256
case GGML_OP_SSM_CONV:
1255
1257
case GGML_OP_SSM_SCAN:
1256
1258
case GGML_OP_RWKV_WKV6:
1259
+ case GGML_OP_RWKV_WKV7:
1257
1260
return true ;
1258
1261
case GGML_OP_MUL_MAT:
1259
1262
case GGML_OP_MUL_MAT_ID:
@@ -2216,6 +2219,46 @@ static void ggml_metal_encode_node(
2216
2219
[encoder setBytes: &C length: sizeof (C) atIndex: 9 ];
2217
2220
[encoder setBytes: &H length: sizeof (H) atIndex: 10 ];
2218
2221
2222
+ [encoder dispatchThreadgroups: MTLSizeMake (B * H, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (C/ H, 1 , 1 )];
2223
+ } break ;
2224
+ case GGML_OP_RWKV_WKV7:
2225
+ {
2226
+ const int64_t B = dst->src [6 ]->ne [1 ];
2227
+ const int64_t T = dst->src [0 ]->ne [2 ];
2228
+ const int64_t C = dst->ne [0 ];
2229
+ const int64_t H = dst->src [0 ]->ne [1 ];
2230
+
2231
+ GGML_ASSERT (dst->src [6 ]->type == GGML_TYPE_F32);
2232
+ GGML_ASSERT (C % H == 0 );
2233
+ GGML_ASSERT (C / H == 64 );
2234
+
2235
+ size_t offs_src3 = 0 ;
2236
+ size_t offs_src4 = 0 ;
2237
+ size_t offs_src5 = 0 ;
2238
+ size_t offs_src6 = 0 ;
2239
+
2240
+ id <MTLBuffer > id_src3 = dst->src [3 ] ? ggml_metal_get_buffer (dst->src [3 ], &offs_src3) : nil ;
2241
+ id <MTLBuffer > id_src4 = dst->src [4 ] ? ggml_metal_get_buffer (dst->src [4 ], &offs_src4) : nil ;
2242
+ id <MTLBuffer > id_src5 = dst->src [5 ] ? ggml_metal_get_buffer (dst->src [5 ], &offs_src5) : nil ;
2243
+ id <MTLBuffer > id_src6 = dst->src [6 ] ? ggml_metal_get_buffer (dst->src [6 ], &offs_src6) : nil ;
2244
+
2245
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline ;
2246
+
2247
+ [encoder setComputePipelineState: pipeline];
2248
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2249
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
2250
+ [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
2251
+ [encoder setBuffer: id_src3 offset: offs_src3 atIndex: 3 ];
2252
+ [encoder setBuffer: id_src4 offset: offs_src4 atIndex: 4 ];
2253
+ [encoder setBuffer: id_src5 offset: offs_src5 atIndex: 5 ];
2254
+ [encoder setBuffer: id_src6 offset: offs_src6 atIndex: 6 ];
2255
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 7 ];
2256
+
2257
+ [encoder setBytes: &B length: sizeof (B) atIndex: 8 ];
2258
+ [encoder setBytes: &T length: sizeof (T) atIndex: 9 ];
2259
+ [encoder setBytes: &C length: sizeof (C) atIndex: 10 ];
2260
+ [encoder setBytes: &H length: sizeof (H) atIndex: 11 ];
2261
+
2219
2262
[encoder dispatchThreadgroups: MTLSizeMake (B * H, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (C/ H, 1 , 1 )];
2220
2263
} break ;
2221
2264
case GGML_OP_MUL_MAT:
0 commit comments