Skip to content

Commit 8e01c25

Browse files
committed
metal : GGML_OP_REPEAT
1 parent 54c859e commit 8e01c25

File tree

3 files changed

+56
-48
lines changed

3 files changed

+56
-48
lines changed

ggml/src/ggml-common.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,25 @@ typedef struct {
475475
uint64_t offs;
476476
} ggml_metal_kargs_bin;
477477

478+
// Argument block for kernel_repeat. The host fills one instance and binds it
// with a single setBytes call at buffer index 0; the kernel receives it as
// `constant ggml_metal_kargs_repeat & args`.
//
//   ne00..ne03 : source extents; the kernel uses them as the modulus when
//                wrapping destination indices (i % ne0x).
//   nb00..nb03 : source strides; multiplied by an index and added to a
//                `char *`, i.e. byte offsets.
//   ne0..ne3   : destination extents; the kernel itself only reads ne0
//                (as the bound of its i0 loop).
//   nb0..nb3   : destination byte strides, used the same way as nb0x.
//
// NOTE(review): extents are stored as int32_t here, whereas the previous
// kernel signature took them as `constant int64_t &` - confirm that tensor
// dimensions passed to this op always fit in 32 bits.
// NOTE(review): the host initializer lists fields positionally, so field
// order in this struct must stay in sync with it.
typedef struct {
479+
int32_t ne00;
480+
int32_t ne01;
481+
int32_t ne02;
482+
int32_t ne03;
483+
uint64_t nb00;
484+
uint64_t nb01;
485+
uint64_t nb02;
486+
uint64_t nb03;
487+
int32_t ne0;
488+
int32_t ne1;
489+
int32_t ne2;
490+
int32_t ne3;
491+
uint64_t nb0;
492+
uint64_t nb1;
493+
uint64_t nb2;
494+
uint64_t nb3;
495+
} ggml_metal_kargs_repeat;
496+
478497
typedef struct {
479498
int32_t ne00;
480499
int32_t ne01;

ggml/src/ggml-metal.m

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,25 +1326,29 @@ static void ggml_metal_encode_node(
13261326
default: GGML_ABORT("fatal error");
13271327
}
13281328

1329+
ggml_metal_kargs_repeat args = {
1330+
/*.ne00 =*/ ne00,
1331+
/*.ne01 =*/ ne01,
1332+
/*.ne02 =*/ ne02,
1333+
/*.ne03 =*/ ne03,
1334+
/*.nb00 =*/ nb00,
1335+
/*.nb01 =*/ nb01,
1336+
/*.nb02 =*/ nb02,
1337+
/*.nb03 =*/ nb03,
1338+
/*.ne0 =*/ ne0,
1339+
/*.ne1 =*/ ne1,
1340+
/*.ne2 =*/ ne2,
1341+
/*.ne3 =*/ ne3,
1342+
/*.nb0 =*/ nb0,
1343+
/*.nb1 =*/ nb1,
1344+
/*.nb2 =*/ nb2,
1345+
/*.nb3 =*/ nb3,
1346+
};
1347+
13291348
[encoder setComputePipelineState:pipeline];
1330-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1331-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1332-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1333-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1334-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1335-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1336-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1337-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1338-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1339-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1340-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1341-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1342-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1343-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1344-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1345-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1346-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1347-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1349+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
1350+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1351+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
13481352

13491353
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
13501354

ggml/src/ggml-metal.metal

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -598,41 +598,26 @@ kernel void kernel_div(
598598

599599
// kernel_repeat<T>: copies elements of type T from src0 into dst, wrapping
// every destination index by the matching source extent (i % ne0x), so that
// source elements are reused whenever a destination index exceeds the source
// extent in that dimension.
//
// Work mapping, as read from the body:
//   - i1, i2, i3 come from threadgroup_position_in_grid (.x, .y, .z), so each
//     threadgroup handles one (i1, i2, i3) position of dst.
//   - within the threadgroup, a thread starts at i0 = tpitg.x and advances by
//     ntg.x until i0 reaches ne0.
//   - src0/dst are `char` pointers; all nb* values are therefore applied as
//     byte offsets, and the final load/store is done through `device T *`.
//
// NOTE(review): this span is a diff view and contains BOTH versions of the
// signature and body. Presumably a marker line ending in '-' precedes a
// removed line (separate `constant` scalar arguments, int64_t indices, uint3
// thread ids) and one ending in '+' precedes an added line (single
// `ggml_metal_kargs_repeat` argument, int indices, ushort3 thread ids) -
// confirm against the repository before treating either as the live code.
// NOTE(review): the added version computes indices in `int` and reads
// int32_t extents from `args`; offsets are still widened by the uint64_t
// strides at the multiplication. Confirm 32-bit indices are sufficient.
template<typename T>
600600
kernel void kernel_repeat(
601+
constant ggml_metal_kargs_repeat & args,
601602
device const char * src0,
602603
device char * dst,
603-
constant int64_t & ne00,
604-
constant int64_t & ne01,
605-
constant int64_t & ne02,
606-
constant int64_t & ne03,
607-
constant uint64_t & nb00,
608-
constant uint64_t & nb01,
609-
constant uint64_t & nb02,
610-
constant uint64_t & nb03,
611-
constant int64_t & ne0,
612-
constant int64_t & ne1,
613-
constant int64_t & ne2,
614-
constant int64_t & ne3,
615-
constant uint64_t & nb0,
616-
constant uint64_t & nb1,
617-
constant uint64_t & nb2,
618-
constant uint64_t & nb3,
619-
uint3 tgpig[[threadgroup_position_in_grid]],
620-
uint3 tpitg[[thread_position_in_threadgroup]],
621-
uint3 ntg[[threads_per_threadgroup]]) {
622-
const int64_t i3 = tgpig.z;
623-
const int64_t i2 = tgpig.y;
624-
const int64_t i1 = tgpig.x;
604+
uint3 tgpig[[threadgroup_position_in_grid]],
605+
ushort3 tpitg[[thread_position_in_threadgroup]],
606+
ushort3 ntg[[threads_per_threadgroup]]) {
607+
const int i3 = tgpig.z;
608+
const int i2 = tgpig.y;
609+
const int i1 = tgpig.x;
625610

626-
const int64_t i03 = i3 % ne03;
627-
const int64_t i02 = i2 % ne02;
628-
const int64_t i01 = i1 % ne01;
611+
const int i03 = i3%args.ne03;
612+
const int i02 = i2%args.ne02;
613+
const int i01 = i1%args.ne01;
629614

630-
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
631-
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
615+
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
616+
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
632617

633-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
634-
const int i00 = i0 % ne00;
635-
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
618+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
619+
const int i00 = i0%args.ne00;
620+
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
636621
}
637622
}
638623

0 commit comments

Comments
 (0)