@@ -1485,10 +1485,10 @@ static void ggml_metal_encode_node(
1485
1485
memcpy (&max, ((const int32_t *) dst->op_params ) + 1 , sizeof (float ));
1486
1486
1487
1487
[encoder setComputePipelineState: pipeline];
1488
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1489
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1490
- [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1491
- [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
1488
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1489
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1490
+ [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1491
+ [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
1492
1492
1493
1493
const int64_t n = ggml_nelements (dst);
1494
1494
@@ -1660,6 +1660,7 @@ static void ggml_metal_encode_node(
1660
1660
1661
1661
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline ;
1662
1662
1663
+ // TODO: add ggml_metal_kargs struct
1663
1664
[encoder setComputePipelineState: pipeline];
1664
1665
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1665
1666
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -1735,6 +1736,8 @@ static void ggml_metal_encode_node(
1735
1736
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
1736
1737
const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
1737
1738
1739
+ // TODO: add ggml_metal_kargs struct
1740
+ // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
1738
1741
[encoder setComputePipelineState: pipeline];
1739
1742
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1740
1743
if (id_src1) {
@@ -1751,6 +1754,7 @@ static void ggml_metal_encode_node(
1751
1754
[encoder setBytes: &m0 length: sizeof (m0) atIndex: 8 ];
1752
1755
[encoder setBytes: &m1 length: sizeof (m1) atIndex: 9 ];
1753
1756
[encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 10 ];
1757
+
1754
1758
[encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
1755
1759
1756
1760
[encoder dispatchThreadgroups: MTLSizeMake (ne01*ne02*ne03, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
@@ -1767,6 +1771,7 @@ static void ggml_metal_encode_node(
1767
1771
pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline ;
1768
1772
}
1769
1773
1774
+ // TODO: add ggml_metal_kargs struct
1770
1775
[encoder setComputePipelineState: pipeline];
1771
1776
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1772
1777
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -1791,6 +1796,7 @@ static void ggml_metal_encode_node(
1791
1796
1792
1797
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline ;
1793
1798
1799
+ // TODO: add ggml_metal_kargs struct
1794
1800
[encoder setComputePipelineState: pipeline];
1795
1801
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1796
1802
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -1861,6 +1867,7 @@ static void ggml_metal_encode_node(
1861
1867
1862
1868
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline ;
1863
1869
1870
+ // TODO: add ggml_metal_kargs struct
1864
1871
[encoder setComputePipelineState: pipeline];
1865
1872
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1866
1873
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2599,6 +2606,7 @@ static void ggml_metal_encode_node(
2599
2606
default : GGML_ABORT (" not implemented" );
2600
2607
}
2601
2608
2609
+ // TODO: add ggml_metal_kargs struct
2602
2610
[encoder setComputePipelineState: pipeline];
2603
2611
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2604
2612
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2668,6 +2676,7 @@ static void ggml_metal_encode_node(
2668
2676
2669
2677
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline ;
2670
2678
2679
+ // TODO: add ggml_metal_kargs struct
2671
2680
[encoder setComputePipelineState: pipeline];
2672
2681
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2673
2682
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2857,6 +2866,7 @@ static void ggml_metal_encode_node(
2857
2866
default : GGML_ABORT (" fatal error" );
2858
2867
};
2859
2868
2869
+ // TODO: add ggml_metal_kargs struct
2860
2870
[encoder setComputePipelineState: pipeline];
2861
2871
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
2862
2872
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2897,6 +2907,7 @@ static void ggml_metal_encode_node(
2897
2907
2898
2908
const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline ;
2899
2909
2910
+ // TODO: add ggml_metal_kargs struct
2900
2911
[encoder setComputePipelineState: pipeline];
2901
2912
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2902
2913
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2931,6 +2942,7 @@ static void ggml_metal_encode_node(
2931
2942
2932
2943
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline ;
2933
2944
2945
+ // TODO: add ggml_metal_kargs struct
2934
2946
[encoder setComputePipelineState: pipeline];
2935
2947
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2936
2948
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2967,6 +2979,7 @@ static void ggml_metal_encode_node(
2967
2979
2968
2980
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline ;
2969
2981
2982
+ // TODO: add ggml_metal_kargs struct
2970
2983
[encoder setComputePipelineState: pipeline];
2971
2984
[encoder setBuffer: id_dst offset: offs_dst atIndex: 0 ];
2972
2985
[encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 1 ];
@@ -2988,6 +3001,7 @@ static void ggml_metal_encode_node(
2988
3001
2989
3002
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline ;
2990
3003
3004
+ // TODO: add ggml_metal_kargs struct
2991
3005
[encoder setComputePipelineState: pipeline];
2992
3006
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2993
3007
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3026,6 +3040,7 @@ static void ggml_metal_encode_node(
3026
3040
default : GGML_ABORT (" fatal error" );
3027
3041
};
3028
3042
3043
+ // TODO: add ggml_metal_kargs struct
3029
3044
[encoder setComputePipelineState: pipeline];
3030
3045
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3031
3046
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3044,6 +3059,7 @@ static void ggml_metal_encode_node(
3044
3059
3045
3060
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline ;
3046
3061
3062
+ // TODO: add ggml_metal_kargs struct
3047
3063
[encoder setComputePipelineState: pipeline];
3048
3064
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3049
3065
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3521,6 +3537,7 @@ static void ggml_metal_encode_node(
3521
3537
const int64_t n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
3522
3538
const int64_t n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
3523
3539
3540
+ // TODO: add ggml_metal_kargs struct
3524
3541
[encoder setComputePipelineState: pipeline];
3525
3542
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3526
3543
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
0 commit comments