@@ -498,6 +498,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
498
498
GGML_METAL_KERNEL_TYPE_COS,
499
499
GGML_METAL_KERNEL_TYPE_NEG,
500
500
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
501
+ GGML_METAL_KERNEL_TYPE_MEAN,
501
502
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
502
503
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
503
504
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ @implementation GGMLMetalClass
1454
1455
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
1455
1456
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NEG, neg, true );
1456
1457
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
1458
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MEAN, mean, true );
1457
1459
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true );
1458
1460
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
1459
1461
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true );
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1653
1655
case GGML_OP_LOG:
1654
1656
return false ; // TODO: implement
1655
1657
case GGML_OP_SUM_ROWS:
1658
+ case GGML_OP_MEAN:
1656
1659
case GGML_OP_SOFT_MAX:
1657
1660
case GGML_OP_GROUP_NORM:
1658
1661
return has_simdgroup_reduction && ggml_is_contiguous (op->src [0 ]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
2400
2403
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2401
2404
} break ;
2402
2405
case GGML_OP_SUM_ROWS:
2406
+ case GGML_OP_MEAN:
2403
2407
{
2404
2408
GGML_ASSERT (src0->nb [0 ] == ggml_type_size (src0->type ));
2405
2409
2406
// GGML_OP_SUM_ROWS and GGML_OP_MEAN share this encode path; they differ only
// in which kernel pipeline gets bound.
id<MTLComputePipelineState> pipeline = nil;

switch (dst->op) {
    case GGML_OP_SUM_ROWS:
        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
        break;
    case GGML_OP_MEAN:
        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
        break;
    default:
        GGML_ABORT(" fatal error");
}

// Upper bound on threads per threadgroup reported by the selected pipeline.
const int nth_max = (int) pipeline.maxTotalThreadsPerThreadgroup;

// Threads per threadgroup: start at 32 and double until the count covers
// ne00 or reaches the pipeline limit.
int nth = 32; // SIMD width

while (nth < ne00 && nth < nth_max) {
    nth *= 2;
}

// The doubling loop only stops once nth >= nth_max, so if the limit is not of
// the form 32 * 2^k the loop exits with nth above it. Clamp to the pipeline
// limit first, then to the row length, so the dispatch never requests more
// threads per threadgroup than the pipeline allows.
nth = MIN(nth, nth_max);
nth = MIN(nth, ne00);
2408
2430
2409
2431
ggml_metal_kargs_sum_rows args = {
2410
2432
/* .ne00 =*/ ne00,
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
2434
2456
};
2435
2457
2436
2458
[encoder setComputePipelineState: pipeline];
2437
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2438
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2439
- [encoder setBytes: &args length: sizeof (args) atIndex: 2 ];
2459
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2460
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2461
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
2462
+ [encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
2440
2463
2441
- [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
2464
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth , 1 , 1 )];
2442
2465
} break ;
2443
2466
case GGML_OP_SOFT_MAX:
2444
2467
{
0 commit comments