@@ -1481,10 +1481,10 @@ static void ggml_metal_encode_node(
1481
1481
memcpy (&max, ((const int32_t *) dst->op_params ) + 1 , sizeof (float ));
1482
1482
1483
1483
[encoder setComputePipelineState: pipeline];
1484
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1485
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1486
- [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1487
- [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
1484
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1485
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1486
+ [encoder setBytes: &min length: sizeof (min) atIndex: 2 ];
1487
+ [encoder setBytes: &max length: sizeof (max) atIndex: 3 ];
1488
1488
1489
1489
const int64_t n = ggml_nelements (dst);
1490
1490
@@ -1656,6 +1656,7 @@ static void ggml_metal_encode_node(
1656
1656
1657
1657
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline ;
1658
1658
1659
+ // TODO: add ggml_metal_kargs struct
1659
1660
[encoder setComputePipelineState: pipeline];
1660
1661
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1661
1662
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -1731,6 +1732,8 @@ static void ggml_metal_encode_node(
1731
1732
const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
1732
1733
const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
1733
1734
1735
+ // TODO: add ggml_metal_kargs struct
1736
+ // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
1734
1737
[encoder setComputePipelineState: pipeline];
1735
1738
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1736
1739
if (id_src1) {
@@ -1747,6 +1750,7 @@ static void ggml_metal_encode_node(
1747
1750
[encoder setBytes: &m0 length: sizeof (m0) atIndex: 8 ];
1748
1751
[encoder setBytes: &m1 length: sizeof (m1) atIndex: 9 ];
1749
1752
[encoder setBytes: &n_head_log2 length: sizeof (n_head_log2) atIndex: 10 ];
1753
+
1750
1754
[encoder setThreadgroupMemoryLength: 32 *sizeof (float ) atIndex: 0 ];
1751
1755
1752
1756
[encoder dispatchThreadgroups: MTLSizeMake (ne01*ne02*ne03, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
@@ -1763,6 +1767,7 @@ static void ggml_metal_encode_node(
1763
1767
pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline ;
1764
1768
}
1765
1769
1770
+ // TODO: add ggml_metal_kargs struct
1766
1771
[encoder setComputePipelineState: pipeline];
1767
1772
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1768
1773
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -1787,6 +1792,7 @@ static void ggml_metal_encode_node(
1787
1792
1788
1793
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline ;
1789
1794
1795
+ // TODO: add ggml_metal_kargs struct
1790
1796
[encoder setComputePipelineState: pipeline];
1791
1797
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1792
1798
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -1857,6 +1863,7 @@ static void ggml_metal_encode_node(
1857
1863
1858
1864
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline ;
1859
1865
1866
+ // TODO: add ggml_metal_kargs struct
1860
1867
[encoder setComputePipelineState: pipeline];
1861
1868
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1862
1869
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2595,6 +2602,7 @@ static void ggml_metal_encode_node(
2595
2602
default : GGML_ABORT (" not implemented" );
2596
2603
}
2597
2604
2605
+ // TODO: add ggml_metal_kargs struct
2598
2606
[encoder setComputePipelineState: pipeline];
2599
2607
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2600
2608
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
@@ -2664,6 +2672,7 @@ static void ggml_metal_encode_node(
2664
2672
2665
2673
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline ;
2666
2674
2675
+ // TODO: add ggml_metal_kargs struct
2667
2676
[encoder setComputePipelineState: pipeline];
2668
2677
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2669
2678
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2853,6 +2862,7 @@ static void ggml_metal_encode_node(
2853
2862
default : GGML_ABORT (" fatal error" );
2854
2863
};
2855
2864
2865
+ // TODO: add ggml_metal_kargs struct
2856
2866
[encoder setComputePipelineState: pipeline];
2857
2867
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
2858
2868
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2893,6 +2903,7 @@ static void ggml_metal_encode_node(
2893
2903
2894
2904
const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline ;
2895
2905
2906
+ // TODO: add ggml_metal_kargs struct
2896
2907
[encoder setComputePipelineState: pipeline];
2897
2908
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2898
2909
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2927,6 +2938,7 @@ static void ggml_metal_encode_node(
2927
2938
2928
2939
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline ;
2929
2940
2941
+ // TODO: add ggml_metal_kargs struct
2930
2942
[encoder setComputePipelineState: pipeline];
2931
2943
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2932
2944
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -2963,6 +2975,7 @@ static void ggml_metal_encode_node(
2963
2975
2964
2976
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline ;
2965
2977
2978
+ // TODO: add ggml_metal_kargs struct
2966
2979
[encoder setComputePipelineState: pipeline];
2967
2980
[encoder setBuffer: id_dst offset: offs_dst atIndex: 0 ];
2968
2981
[encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 1 ];
@@ -2984,6 +2997,7 @@ static void ggml_metal_encode_node(
2984
2997
2985
2998
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline ;
2986
2999
3000
+ // TODO: add ggml_metal_kargs struct
2987
3001
[encoder setComputePipelineState: pipeline];
2988
3002
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
2989
3003
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3022,6 +3036,7 @@ static void ggml_metal_encode_node(
3022
3036
default : GGML_ABORT (" fatal error" );
3023
3037
};
3024
3038
3039
+ // TODO: add ggml_metal_kargs struct
3025
3040
[encoder setComputePipelineState: pipeline];
3026
3041
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3027
3042
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3040,6 +3055,7 @@ static void ggml_metal_encode_node(
3040
3055
3041
3056
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline ;
3042
3057
3058
+ // TODO: add ggml_metal_kargs struct
3043
3059
[encoder setComputePipelineState: pipeline];
3044
3060
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3045
3061
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
@@ -3517,6 +3533,7 @@ static void ggml_metal_encode_node(
3517
3533
const int64_t n_threads = MIN ((int64_t )[pipeline maxTotalThreadsPerThreadgroup ], parallel_elements);
3518
3534
const int64_t n_tg = (parallel_elements + n_threads - 1 ) / n_threads;
3519
3535
3536
+ // TODO: add ggml_metal_kargs struct
3520
3537
[encoder setComputePipelineState: pipeline];
3521
3538
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3522
3539
[encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
0 commit comments