Skip to content

Commit 1c60302

Browse files
committed
metal : add TODOs for rest of ops
1 parent f018669 commit 1c60302

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1485,10 +1485,10 @@ static void ggml_metal_encode_node(
14851485
memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
14861486

14871487
[encoder setComputePipelineState:pipeline];
1488-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1489-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1490-
[encoder setBytes:&min length:sizeof(min) atIndex:2];
1491-
[encoder setBytes:&max length:sizeof(max) atIndex:3];
1488+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1489+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1490+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
1491+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
14921492

14931493
const int64_t n = ggml_nelements(dst);
14941494

@@ -1660,6 +1660,7 @@ static void ggml_metal_encode_node(
16601660

16611661
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
16621662

1663+
// TODO: add ggml_metal_kargs struct
16631664
[encoder setComputePipelineState:pipeline];
16641665
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
16651666
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -1735,6 +1736,8 @@ static void ggml_metal_encode_node(
17351736
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
17361737
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
17371738

1739+
// TODO: add ggml_metal_kargs struct
1740+
// TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
17381741
[encoder setComputePipelineState:pipeline];
17391742
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
17401743
if (id_src1) {
@@ -1751,6 +1754,7 @@ static void ggml_metal_encode_node(
17511754
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
17521755
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
17531756
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1757+
17541758
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
17551759

17561760
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1767,6 +1771,7 @@ static void ggml_metal_encode_node(
17671771
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
17681772
}
17691773

1774+
// TODO: add ggml_metal_kargs struct
17701775
[encoder setComputePipelineState:pipeline];
17711776
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
17721777
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -1791,6 +1796,7 @@ static void ggml_metal_encode_node(
17911796

17921797
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
17931798

1799+
// TODO: add ggml_metal_kargs struct
17941800
[encoder setComputePipelineState:pipeline];
17951801
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
17961802
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1861,6 +1867,7 @@ static void ggml_metal_encode_node(
18611867

18621868
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
18631869

1870+
// TODO: add ggml_metal_kargs struct
18641871
[encoder setComputePipelineState:pipeline];
18651872
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
18661873
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2599,6 +2606,7 @@ static void ggml_metal_encode_node(
25992606
default: GGML_ABORT("not implemented");
26002607
}
26012608

2609+
// TODO: add ggml_metal_kargs struct
26022610
[encoder setComputePipelineState:pipeline];
26032611
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
26042612
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2668,6 +2676,7 @@ static void ggml_metal_encode_node(
26682676

26692677
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
26702678

2679+
// TODO: add ggml_metal_kargs struct
26712680
[encoder setComputePipelineState:pipeline];
26722681
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
26732682
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2857,6 +2866,7 @@ static void ggml_metal_encode_node(
28572866
default: GGML_ABORT("fatal error");
28582867
};
28592868

2869+
// TODO: add ggml_metal_kargs struct
28602870
[encoder setComputePipelineState:pipeline];
28612871
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
28622872
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2897,6 +2907,7 @@ static void ggml_metal_encode_node(
28972907

28982908
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
28992909

2910+
// TODO: add ggml_metal_kargs struct
29002911
[encoder setComputePipelineState:pipeline];
29012912
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
29022913
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2931,6 +2942,7 @@ static void ggml_metal_encode_node(
29312942

29322943
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
29332944

2945+
// TODO: add ggml_metal_kargs struct
29342946
[encoder setComputePipelineState:pipeline];
29352947
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
29362948
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -2967,6 +2979,7 @@ static void ggml_metal_encode_node(
29672979

29682980
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
29692981

2982+
// TODO: add ggml_metal_kargs struct
29702983
[encoder setComputePipelineState:pipeline];
29712984
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
29722985
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
@@ -2988,6 +3001,7 @@ static void ggml_metal_encode_node(
29883001

29893002
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
29903003

3004+
// TODO: add ggml_metal_kargs struct
29913005
[encoder setComputePipelineState:pipeline];
29923006
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
29933007
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -3026,6 +3040,7 @@ static void ggml_metal_encode_node(
30263040
default: GGML_ABORT("fatal error");
30273041
};
30283042

3043+
// TODO: add ggml_metal_kargs struct
30293044
[encoder setComputePipelineState:pipeline];
30303045
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
30313046
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -3044,6 +3059,7 @@ static void ggml_metal_encode_node(
30443059

30453060
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
30463061

3062+
// TODO: add ggml_metal_kargs struct
30473063
[encoder setComputePipelineState:pipeline];
30483064
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
30493065
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
@@ -3521,6 +3537,7 @@ static void ggml_metal_encode_node(
35213537
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
35223538
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
35233539

3540+
// TODO: add ggml_metal_kargs struct
35243541
[encoder setComputePipelineState:pipeline];
35253542
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
35263543
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];

0 commit comments

Comments
 (0)