Skip to content

Commit 061ddc6

Browse files
committed
metal : add TODOs for rest of ops
1 parent 7941b6b commit 061ddc6

File tree

1 file changed

+21 -4 lines changed

1 file changed

+21 -4 lines changed

ggml/src/ggml-metal.m

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1481,10 +1481,10 @@ static void ggml_metal_encode_node(
14811481
memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
14821482

14831483
[encoder setComputePipelineState:pipeline];
1484-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1485-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1486-
[encoder setBytes:&min length:sizeof(min) atIndex:2];
1487-
[encoder setBytes:&max length:sizeof(max) atIndex:3];
1484+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1485+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1486+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
1487+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
14881488

14891489
const int64_t n = ggml_nelements(dst);
14901490

@@ -1656,6 +1656,7 @@ static void ggml_metal_encode_node(
16561656

16571657
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
16581658

1659+
// TODO: add ggml_metal_kargs struct
16591660
[encoder setComputePipelineState:pipeline];
16601661
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
16611662
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -1731,6 +1732,8 @@ static void ggml_metal_encode_node(
17311732
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
17321733
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
17331734

1735+
// TODO: add ggml_metal_kargs struct
1736+
// TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
17341737
[encoder setComputePipelineState:pipeline];
17351738
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
17361739
if (id_src1) {
@@ -1747,6 +1750,7 @@ static void ggml_metal_encode_node(
17471750
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
17481751
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
17491752
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1753+
17501754
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
17511755

17521756
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1763,6 +1767,7 @@ static void ggml_metal_encode_node(
17631767
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
17641768
}
17651769

1770+
// TODO: add ggml_metal_kargs struct
17661771
[encoder setComputePipelineState:pipeline];
17671772
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
17681773
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -1787,6 +1792,7 @@ static void ggml_metal_encode_node(
17871792

17881793
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
17891794

1795+
// TODO: add ggml_metal_kargs struct
17901796
[encoder setComputePipelineState:pipeline];
17911797
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
17921798
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1857,6 +1863,7 @@ static void ggml_metal_encode_node(
18571863

18581864
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
18591865

1866+
// TODO: add ggml_metal_kargs struct
18601867
[encoder setComputePipelineState:pipeline];
18611868
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
18621869
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2595,6 +2602,7 @@ static void ggml_metal_encode_node(
25952602
default: GGML_ABORT("not implemented");
25962603
}
25972604

2605+
// TODO: add ggml_metal_kargs struct
25982606
[encoder setComputePipelineState:pipeline];
25992607
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
26002608
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2664,6 +2672,7 @@ static void ggml_metal_encode_node(
26642672

26652673
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
26662674

2675+
// TODO: add ggml_metal_kargs struct
26672676
[encoder setComputePipelineState:pipeline];
26682677
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
26692678
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2853,6 +2862,7 @@ static void ggml_metal_encode_node(
28532862
default: GGML_ABORT("fatal error");
28542863
};
28552864

2865+
// TODO: add ggml_metal_kargs struct
28562866
[encoder setComputePipelineState:pipeline];
28572867
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
28582868
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2893,6 +2903,7 @@ static void ggml_metal_encode_node(
28932903

28942904
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
28952905

2906+
// TODO: add ggml_metal_kargs struct
28962907
[encoder setComputePipelineState:pipeline];
28972908
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
28982909
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2927,6 +2938,7 @@ static void ggml_metal_encode_node(
29272938

29282939
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
29292940

2941+
// TODO: add ggml_metal_kargs struct
29302942
[encoder setComputePipelineState:pipeline];
29312943
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
29322944
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2963,6 +2975,7 @@ static void ggml_metal_encode_node(
29632975

29642976
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
29652977

2978+
// TODO: add ggml_metal_kargs struct
29662979
[encoder setComputePipelineState:pipeline];
29672980
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
29682981
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
@@ -2984,6 +2997,7 @@ static void ggml_metal_encode_node(
29842997

29852998
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
29862999

3000+
// TODO: add ggml_metal_kargs struct
29873001
[encoder setComputePipelineState:pipeline];
29883002
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
29893003
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -3022,6 +3036,7 @@ static void ggml_metal_encode_node(
30223036
default: GGML_ABORT("fatal error");
30233037
};
30243038

3039+
// TODO: add ggml_metal_kargs struct
30253040
[encoder setComputePipelineState:pipeline];
30263041
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
30273042
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -3040,6 +3055,7 @@ static void ggml_metal_encode_node(
30403055

30413056
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
30423057

3058+
// TODO: add ggml_metal_kargs struct
30433059
[encoder setComputePipelineState:pipeline];
30443060
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
30453061
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -3517,6 +3533,7 @@ static void ggml_metal_encode_node(
35173533
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
35183534
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
35193535

3536+
// TODO: add ggml_metal_kargs struct
35203537
[encoder setComputePipelineState:pipeline];
35213538
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
35223539
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

0 commit comments

Comments
 (0)