@@ -174,6 +174,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
174
174
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
175
175
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
176
176
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
177
+ GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
177
178
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
178
179
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
179
180
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -783,6 +784,7 @@ @implementation GGMLMetalClass
783
784
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
784
785
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
785
786
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true );
787
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true );
786
788
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
787
789
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
788
790
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1246,6 +1248,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1246
1248
case GGML_OP_SSM_CONV:
1247
1249
case GGML_OP_SSM_SCAN:
1248
1250
case GGML_OP_RWKV_WKV6:
1251
+ case GGML_OP_RWKV_WKV7:
1249
1252
return true ;
1250
1253
case GGML_OP_MUL_MAT:
1251
1254
case GGML_OP_MUL_MAT_ID:
@@ -2208,6 +2211,46 @@ static void ggml_metal_encode_node(
2208
2211
[encoder setBytes: &C length: sizeof (C) atIndex: 9 ];
2209
2212
[encoder setBytes: &H length: sizeof (H) atIndex: 10 ];
2210
2213
2214
+ [encoder dispatchThreadgroups: MTLSizeMake (B * H, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (C/ H, 1 , 1 )];
2215
+ } break ;
2216
+ case GGML_OP_RWKV_WKV7:
2217
+ {
2218
+ const int64_t B = dst->src [6 ]->ne [1 ];
2219
+ const int64_t T = dst->src [0 ]->ne [2 ];
2220
+ const int64_t C = dst->ne [0 ];
2221
+ const int64_t H = dst->src [0 ]->ne [1 ];
2222
+
2223
+ GGML_ASSERT (dst->src [6 ]->type == GGML_TYPE_F32);
2224
+ GGML_ASSERT (C % H == 0 );
2225
+ GGML_ASSERT (C / H == 64 );
2226
+
2227
+ size_t offs_src3 = 0 ;
2228
+ size_t offs_src4 = 0 ;
2229
+ size_t offs_src5 = 0 ;
2230
+ size_t offs_src6 = 0 ;
2231
+
2232
+ id <MTLBuffer > id_src3 = dst->src [3 ] ? ggml_metal_get_buffer (dst->src [3 ], &offs_src3) : nil ;
2233
+ id <MTLBuffer > id_src4 = dst->src [4 ] ? ggml_metal_get_buffer (dst->src [4 ], &offs_src4) : nil ;
2234
+ id <MTLBuffer > id_src5 = dst->src [5 ] ? ggml_metal_get_buffer (dst->src [5 ], &offs_src5) : nil ;
2235
+ id <MTLBuffer > id_src6 = dst->src [6 ] ? ggml_metal_get_buffer (dst->src [6 ], &offs_src6) : nil ;
2236
+
2237
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32].pipeline ;
2238
+
2239
+ [encoder setComputePipelineState: pipeline];
2240
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2241
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
2242
+ [encoder setBuffer: id_src2 offset: offs_src2 atIndex: 2 ];
2243
+ [encoder setBuffer: id_src3 offset: offs_src3 atIndex: 3 ];
2244
+ [encoder setBuffer: id_src4 offset: offs_src4 atIndex: 4 ];
2245
+ [encoder setBuffer: id_src5 offset: offs_src5 atIndex: 5 ];
2246
+ [encoder setBuffer: id_src6 offset: offs_src6 atIndex: 6 ];
2247
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 7 ];
2248
+
2249
+ [encoder setBytes: &B length: sizeof (B) atIndex: 8 ];
2250
+ [encoder setBytes: &T length: sizeof (T) atIndex: 9 ];
2251
+ [encoder setBytes: &C length: sizeof (C) atIndex: 10 ];
2252
+ [encoder setBytes: &H length: sizeof (H) atIndex: 11 ];
2253
+
2211
2254
[encoder dispatchThreadgroups: MTLSizeMake (B * H, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (C/ H, 1 , 1 )];
2212
2255
} break ;
2213
2256
case GGML_OP_MUL_MAT:
0 commit comments