@@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
241
241
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
242
242
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
243
243
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
244
+ GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
245
+ GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
244
246
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
245
247
GGML_METAL_KERNEL_TYPE_PAD_F32,
246
248
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
272
274
GGML_METAL_KERNEL_TYPE_SIN,
273
275
GGML_METAL_KERNEL_TYPE_COS,
274
276
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
277
+ GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
278
+ GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
275
279
276
280
GGML_METAL_KERNEL_TYPE_COUNT
277
281
};
@@ -685,6 +689,8 @@ @implementation GGMLMetalClass
685
689
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true );
686
690
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true );
687
691
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true );
692
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true );
693
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true );
688
694
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true );
689
695
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true );
690
696
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true );
@@ -716,6 +722,8 @@ @implementation GGMLMetalClass
716
722
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIN, sin, true );
717
723
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_COS, cos, true );
718
724
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true );
725
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true );
726
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true );
719
727
}
720
728
721
729
[metal_library release ];
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
844
852
case GGML_OP_IM2COL:
845
853
return op->src [0 ]->type == GGML_TYPE_F16;
846
854
case GGML_OP_POOL_1D:
847
- case GGML_OP_POOL_2D:
848
855
return false ;
856
+ case GGML_OP_POOL_2D:
849
857
case GGML_OP_UPSCALE:
850
858
case GGML_OP_PAD:
851
859
case GGML_OP_ARANGE:
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
2545
2553
} break ;
2546
2554
case GGML_OP_IM2COL:
2547
2555
{
2556
+ GGML_ASSERT (ggml_is_contiguous (src0));
2557
+ GGML_ASSERT (ggml_is_contiguous (src1));
2548
2558
GGML_ASSERT (src0->type == GGML_TYPE_F16);
2549
2559
GGML_ASSERT (src1->type == GGML_TYPE_F32);
2550
2560
GGML_ASSERT ( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
2574
2584
const int32_t ofs0 = src1->nb [is_2D ? 3 : 2 ] / 4 ;
2575
2585
const int32_t ofs1 = src1->nb [is_2D ? 2 : 1 ] / 4 ;
2576
2586
2577
- id <MTLComputePipelineState > pipeline = nil ;
2587
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline ;
2588
+
2589
+ const bool is_gt_mttpt = ((size_t )(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup ;
2578
2590
2579
2591
switch (dst->type ) {
2580
- case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline ; break ;
2581
- case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline ; break ;
2592
+ case GGML_TYPE_F32: {
2593
+ pipeline = (is_gt_mttpt ?
2594
+ ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
2595
+ :
2596
+ ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline );
2597
+ } break ;
2598
+ case GGML_TYPE_F16: {
2599
+ pipeline = (is_gt_mttpt ?
2600
+ ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
2601
+ :
2602
+ ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline );
2603
+ } break ;
2582
2604
default : GGML_ABORT (" fatal error" );
2583
2605
};
2584
2606
2585
2607
[encoder setComputePipelineState: pipeline];
2586
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
2587
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2588
- [encoder setBytes: &ofs0 length: sizeof ( int32_t ) atIndex: 2 ];
2589
- [encoder setBytes: &ofs1 length: sizeof ( int32_t ) atIndex: 3 ];
2590
- [encoder setBytes: &IW length: sizeof ( int32_t ) atIndex: 4 ];
2591
- [encoder setBytes: &IH length: sizeof ( int32_t ) atIndex: 5 ];
2592
- [encoder setBytes: &CHW length: sizeof ( int32_t ) atIndex: 6 ];
2593
- [encoder setBytes: &s0 length: sizeof ( int32_t ) atIndex: 7 ];
2594
- [encoder setBytes: &s1 length: sizeof ( int32_t ) atIndex: 8 ];
2595
- [encoder setBytes: &p0 length: sizeof ( int32_t ) atIndex: 9 ];
2596
- [encoder setBytes: &p1 length: sizeof ( int32_t ) atIndex: 10 ];
2597
- [encoder setBytes: &d0 length: sizeof ( int32_t ) atIndex: 11 ];
2598
- [encoder setBytes: &d1 length: sizeof ( int32_t ) atIndex: 12 ];
2599
-
2600
- [encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
2608
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
2609
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
2610
+ [encoder setBytes: &ofs0 length: sizeof (int32_t ) atIndex: 2 ];
2611
+ [encoder setBytes: &ofs1 length: sizeof (int32_t ) atIndex: 3 ];
2612
+ [encoder setBytes: &IW length: sizeof (int32_t ) atIndex: 4 ];
2613
+ [encoder setBytes: &IH length: sizeof (int32_t ) atIndex: 5 ];
2614
+ [encoder setBytes: &CHW length: sizeof (int32_t ) atIndex: 6 ];
2615
+ [encoder setBytes: &s0 length: sizeof (int32_t ) atIndex: 7 ];
2616
+ [encoder setBytes: &s1 length: sizeof (int32_t ) atIndex: 8 ];
2617
+ [encoder setBytes: &p0 length: sizeof (int32_t ) atIndex: 9 ];
2618
+ [encoder setBytes: &p1 length: sizeof (int32_t ) atIndex: 10 ];
2619
+ [encoder setBytes: &d0 length: sizeof (int32_t ) atIndex: 11 ];
2620
+ [encoder setBytes: &d1 length: sizeof (int32_t ) atIndex: 12 ];
2621
+
2622
+ if (is_gt_mttpt) {
2623
+ [encoder setBytes: &N length: sizeof (int32_t ) atIndex: 13 ];
2624
+ [encoder setBytes: &KH length: sizeof (int32_t ) atIndex: 14 ];
2625
+ [encoder setBytes: &KW length: sizeof (int32_t ) atIndex: 15 ];
2626
+
2627
+ const uint64_t n_threads = MIN (pipeline.maxTotalThreadsPerThreadgroup , (uint64_t )N);
2628
+
2629
+ const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0 );
2630
+
2631
+ [encoder dispatchThreadgroups: MTLSizeMake (quotient * CHW, OH, OW) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
2632
+ } else {
2633
+ [encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
2634
+ }
2601
2635
} break ;
2602
2636
case GGML_OP_UPSCALE:
2603
2637
{
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
3001
3035
3002
3036
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
3003
3037
} break ;
3038
+ case GGML_OP_POOL_2D:
3039
+ {
3040
+ GGML_ASSERT (ggml_is_contiguous (src0));
3041
+ GGML_ASSERT (src0t == GGML_TYPE_F32 && src0t == dstt);
3042
+
3043
+ const int32_t * opts = dst->op_params ;
3044
+ enum ggml_op_pool op = opts[0 ];
3045
+
3046
+ id <MTLComputePipelineState > pipeline = nil ;
3047
+ switch (src0t) {
3048
+ case GGML_TYPE_F32: {
3049
+ switch (op) {
3050
+ case GGML_OP_POOL_AVG:
3051
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline ; break ;
3052
+ case GGML_OP_POOL_MAX:
3053
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline ; break ;
3054
+ default : GGML_ASSERT (false && " not implemented" );
3055
+ }
3056
+ } break ;
3057
+ default : GGML_ASSERT (false && " not implemented" );
3058
+ }
3059
+
3060
+ const int32_t k0 = opts[1 ];
3061
+ const int32_t k1 = opts[2 ];
3062
+ const int32_t s0 = opts[3 ];
3063
+ const int32_t s1 = opts[4 ];
3064
+ const int32_t p0 = opts[5 ];
3065
+ const int32_t p1 = opts[6 ];
3066
+
3067
+ const int64_t IH = src0->ne [1 ];
3068
+ const int64_t IW = src0->ne [0 ];
3069
+
3070
+ const int64_t N = dst->ne [3 ];
3071
+ const int64_t OC = dst->ne [2 ];
3072
+ const int64_t OH = dst->ne [1 ];
3073
+ const int64_t OW = dst->ne [0 ];
3074
+
3075
+ const int64_t parallel_elements = N * OC * OH * OW;
3076
+ const int64_t n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
3077
+ const int64_t n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
3078
+
3079
+ [encoder setComputePipelineState: pipeline];
3080
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3081
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
3082
+ [encoder setBytes: &k0 length: sizeof (int32_t ) atIndex: 2 ];
3083
+ [encoder setBytes: &k1 length: sizeof (int32_t ) atIndex: 3 ];
3084
+ [encoder setBytes: &s0 length: sizeof (int32_t ) atIndex: 4 ];
3085
+ [encoder setBytes: &s1 length: sizeof (int32_t ) atIndex: 5 ];
3086
+ [encoder setBytes: &p0 length: sizeof (int32_t ) atIndex: 6 ];
3087
+ [encoder setBytes: &p1 length: sizeof (int32_t ) atIndex: 7 ];
3088
+ [encoder setBytes: &IH length: sizeof (int64_t ) atIndex: 8 ];
3089
+ [encoder setBytes: &IW length: sizeof (int64_t ) atIndex: 9 ];
3090
+ [encoder setBytes: &OH length: sizeof (int64_t ) atIndex: 10 ];
3091
+ [encoder setBytes: &OW length: sizeof (int64_t ) atIndex: 11 ];
3092
+ [encoder setBytes: &parallel_elements length: sizeof (int64_t ) atIndex: 12 ];
3093
+
3094
+ [encoder dispatchThreadgroups: MTLSizeMake (n_tg, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
3095
+ } break ;
3004
3096
default :
3005
3097
{
3006
3098
GGML_LOG_ERROR (" %s : error: node %3d , op = %8s not implemented\n " , __func__, idx, ggml_op_name (dst->op ));
0 commit comments