Skip to content

Commit 967727a

Browse files
committed
metal : GGML_OP_CPY
1 parent 8e01c25 commit 967727a

File tree

3 files changed

+182
-261
lines changed

3 files changed

+182
-261
lines changed

ggml/src/ggml-common.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,25 @@ typedef struct {
494494
uint64_t nb3;
495495
} ggml_metal_kargs_repeat;
496496

497+
// Argument block for the Metal copy (GGML_OP_CPY) pipelines.
//
// The host code in ggml-metal.m fills one instance and uploads it with a single
// setBytes call at buffer index 0; the src0 and dst buffers are then bound at
// indices 1 and 2. This replaces sixteen separate setBytes calls (indices 2..17).
//
// The host initializes this struct positionally (the field names there are only
// comments), so the field order below must not change without updating every
// initializer.
//
// NOTE(review): the 0x-suffixed fields (ne00.., nb00..) presumably describe the
// source tensor and the single-digit fields (ne0.., nb0..) the destination, with
// ne* being per-dimension element counts and nb* per-dimension byte strides --
// confirm against the kernel that consumes this struct.
typedef struct {
498+
int64_t ne00; // first of four int64_t ne0x fields
499+
int64_t ne01;
500+
int64_t ne02;
501+
int64_t ne03;
502+
uint64_t nb00; // first of four uint64_t nb0x fields
503+
uint64_t nb01;
504+
uint64_t nb02;
505+
uint64_t nb03;
506+
int64_t ne0; // first of four int64_t nex fields
507+
int64_t ne1;
508+
int64_t ne2;
509+
int64_t ne3;
510+
uint64_t nb0; // first of four uint64_t nbx fields
511+
uint64_t nb1;
512+
uint64_t nb2;
513+
uint64_t nb3;
514+
} ggml_metal_kargs_cpy;
515+
497516
typedef struct {
498517
int32_t ne00;
499518
int32_t ne01;

ggml/src/ggml-metal.m

Lines changed: 44 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,25 +1377,29 @@ static void ggml_metal_encode_node(
13771377

13781378
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
13791379

1380+
ggml_metal_kargs_cpy args = {
1381+
/*.ne00 =*/ ne00,
1382+
/*.ne01 =*/ ne01,
1383+
/*.ne02 =*/ ne02,
1384+
/*.ne03 =*/ ne03,
1385+
/*.nb00 =*/ nb00,
1386+
/*.nb01 =*/ nb01,
1387+
/*.nb02 =*/ nb02,
1388+
/*.nb03 =*/ nb03,
1389+
/*.ne0 =*/ ne0,
1390+
/*.ne1 =*/ ne1,
1391+
/*.ne2 =*/ ne2,
1392+
/*.ne3 =*/ ne3,
1393+
/*.nb0 =*/ nb0,
1394+
/*.nb1 =*/ nb1,
1395+
/*.nb2 =*/ nb2,
1396+
/*.nb3 =*/ nb3,
1397+
};
1398+
13801399
[encoder setComputePipelineState:pipeline];
1381-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1382-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1383-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1384-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1385-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1386-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1387-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1388-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1389-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1390-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1391-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1392-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1393-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1394-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1395-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1396-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1397-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1398-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1400+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
1401+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1402+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
13991403

14001404
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
14011405

@@ -3425,25 +3429,29 @@ static void ggml_metal_encode_node(
34253429
default: GGML_ABORT("not implemented");
34263430
}
34273431

3432+
ggml_metal_kargs_cpy args = {
3433+
/*.ne00 =*/ ne00,
3434+
/*.ne01 =*/ ne01,
3435+
/*.ne02 =*/ ne02,
3436+
/*.ne03 =*/ ne03,
3437+
/*.nb00 =*/ nb00,
3438+
/*.nb01 =*/ nb01,
3439+
/*.nb02 =*/ nb02,
3440+
/*.nb03 =*/ nb03,
3441+
/*.ne0 =*/ ne0,
3442+
/*.ne1 =*/ ne1,
3443+
/*.ne2 =*/ ne2,
3444+
/*.ne3 =*/ ne3,
3445+
/*.nb0 =*/ nb0,
3446+
/*.nb1 =*/ nb1,
3447+
/*.nb2 =*/ nb2,
3448+
/*.nb3 =*/ nb3,
3449+
};
3450+
34283451
[encoder setComputePipelineState:pipeline];
3429-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3430-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3431-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
3432-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
3433-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
3434-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
3435-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
3436-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
3437-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
3438-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
3439-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
3440-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
3441-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
3442-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
3443-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
3444-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
3445-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
3446-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
3452+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3453+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3454+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
34473455

34483456
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
34493457
} break;

0 commit comments

Comments
 (0)