Skip to content

Commit 5bb31cc

Browse files
committed
metal : add GGML_OP_REPEAT kernels
ggml-ci
1 parent 62bfef5 commit 5bb31cc

File tree

2 files changed

+95
-7
lines changed

2 files changed

+95
-7
lines changed

ggml-metal.m

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
GGML_METAL_KERNEL_TYPE_MUL_ROW,
3636
GGML_METAL_KERNEL_TYPE_DIV,
3737
GGML_METAL_KERNEL_TYPE_DIV_ROW,
38+
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39+
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40+
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41+
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
3842
GGML_METAL_KERNEL_TYPE_SCALE,
3943
GGML_METAL_KERNEL_TYPE_SCALE_4,
4044
GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -485,6 +489,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
485489
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
486490
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
487491
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
492+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
493+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
494+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
495+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
488496
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
489497
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
490498
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
746754
case GGML_OP_ACC:
747755
case GGML_OP_MUL:
748756
case GGML_OP_DIV:
757+
case GGML_OP_REPEAT:
749758
case GGML_OP_SCALE:
750759
case GGML_OP_CLAMP:
751760
case GGML_OP_SQR:
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
979988
switch (dst->op) {
980989
case GGML_OP_CONCAT:
981990
{
982-
const int64_t nb = ne00;
983-
984991
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
985992

986993
[encoder setComputePipelineState:pipeline];
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
10111018
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
10121019
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
10131020
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1014-
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
10151021

10161022
const int nth = MIN(1024, ne0);
10171023

@@ -1021,12 +1027,13 @@ static enum ggml_status ggml_metal_graph_compute(
10211027
case GGML_OP_MUL:
10221028
case GGML_OP_DIV:
10231029
{
1030+
GGML_ASSERT(src0t == GGML_TYPE_F32);
1031+
GGML_ASSERT(src1t == GGML_TYPE_F32);
1032+
10241033
const size_t offs = 0;
10251034

10261035
bool bcast_row = false;
10271036

1028-
int64_t nb = ne00;
1029-
10301037
id<MTLComputePipelineState> pipeline = nil;
10311038

10321039
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
@@ -1035,7 +1042,6 @@ static enum ggml_status ggml_metal_graph_compute(
10351042
// src1 is a row
10361043
GGML_ASSERT(ne11 == 1);
10371044

1038-
nb = ne00 / 4;
10391045
switch (dst->op) {
10401046
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
10411047
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
@@ -1082,7 +1088,6 @@ static enum ggml_status ggml_metal_graph_compute(
10821088
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
10831089
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
10841090
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1085-
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
10861091

10871092
if (bcast_row) {
10881093
const int64_t n = ggml_nelements(dst)/4;
@@ -1094,6 +1099,42 @@ static enum ggml_status ggml_metal_graph_compute(
10941099
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
10951100
}
10961101
} break;
1102+
case GGML_OP_REPEAT:
1103+
{
1104+
id<MTLComputePipelineState> pipeline;
1105+
1106+
switch (src0t) {
1107+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1108+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1109+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1110+
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1111+
default: GGML_ASSERT(false);
1112+
}
1113+
1114+
[encoder setComputePipelineState:pipeline];
1115+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1116+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1117+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1118+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1119+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1120+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1121+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1122+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1123+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1124+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1125+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1126+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1127+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1128+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1129+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1130+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1131+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1132+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1133+
1134+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1135+
1136+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1137+
} break;
10971138
case GGML_OP_ACC:
10981139
{
10991140
GGML_ASSERT(src0t == GGML_TYPE_F32);

ggml-metal.metal

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,53 @@ kernel void kernel_div(
168168
}
169169
}
170170

171+
template<typename T>
172+
kernel void kernel_repeat(
173+
device const char * src0,
174+
device char * dst,
175+
constant int64_t & ne00,
176+
constant int64_t & ne01,
177+
constant int64_t & ne02,
178+
constant int64_t & ne03,
179+
constant uint64_t & nb00,
180+
constant uint64_t & nb01,
181+
constant uint64_t & nb02,
182+
constant uint64_t & nb03,
183+
constant int64_t & ne0,
184+
constant int64_t & ne1,
185+
constant int64_t & ne2,
186+
constant int64_t & ne3,
187+
constant uint64_t & nb0,
188+
constant uint64_t & nb1,
189+
constant uint64_t & nb2,
190+
constant uint64_t & nb3,
191+
uint3 tgpig[[threadgroup_position_in_grid]],
192+
uint3 tpitg[[thread_position_in_threadgroup]],
193+
uint3 ntg[[threads_per_threadgroup]]) {
194+
const int64_t i3 = tgpig.z;
195+
const int64_t i2 = tgpig.y;
196+
const int64_t i1 = tgpig.x;
197+
198+
const int64_t i03 = i3 % ne03;
199+
const int64_t i02 = i2 % ne02;
200+
const int64_t i01 = i1 % ne01;
201+
202+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
203+
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
204+
205+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
206+
const int i00 = i0 % ne00;
207+
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
208+
}
209+
}
210+
211+
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
212+
213+
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
214+
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
215+
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
216+
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
217+
171218
// assumption: src1 is a row
172219
// broadcast src1 into src0
173220
kernel void kernel_add_row(

0 commit comments

Comments
 (0)