Skip to content

Commit 54c859e

Browse files
committed
metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV
1 parent e418ccf commit 54c859e

File tree

3 files changed

+164
-222
lines changed

3 files changed

+164
-222
lines changed

ggml/src/ggml-common.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,34 @@ typedef struct {
447447
int32_t dim;
448448
} ggml_metal_kargs_concat;
449449

450+
// Kernel-argument block for the Metal binary element-wise ops (the host side
// selects ADD/SUB/... pipelines and uploads one instance of this struct with
// a single setBytes call at buffer index 0, followed by the src0/src1/dst buffers).
// NOTE: the host fills this struct with a positional initializer, so the order
// and types of the fields below must stay in sync with that initializer.
// NOTE(review): the ne*/nb* naming presumably follows the usual ggml convention
// (ne = number of elements per dimension, nb = stride in bytes) - confirm.
typedef struct {
451+
int32_t ne00; // ne00..ne03: presumably src0 extents - TODO confirm
452+
int32_t ne01;
453+
int32_t ne02;
454+
int32_t ne03;
455+
uint64_t nb00; // nb00..nb03: presumably src0 byte strides - TODO confirm
456+
uint64_t nb01;
457+
uint64_t nb02;
458+
uint64_t nb03;
459+
int32_t ne10; // ne10..ne13: presumably src1 extents - TODO confirm
460+
int32_t ne11;
461+
int32_t ne12;
462+
int32_t ne13;
463+
uint64_t nb10; // nb10..nb13: presumably src1 byte strides - TODO confirm
464+
uint64_t nb11;
465+
uint64_t nb12;
466+
uint64_t nb13;
467+
int32_t ne0; // ne0..ne3: presumably dst extents - TODO confirm
468+
int32_t ne1;
469+
int32_t ne2;
470+
int32_t ne3;
471+
uint64_t nb0; // nb0..nb3: presumably dst byte strides - TODO confirm
472+
uint64_t nb1;
473+
uint64_t nb2;
474+
uint64_t nb3;
475+
uint64_t offs; // NOTE(review): looks like an extra offset consumed by the kernel - meaning not visible here, confirm
476+
} ggml_metal_kargs_bin;
477+
450478
typedef struct {
451479
int32_t ne00;
452480
int32_t ne01;

ggml/src/ggml-metal.m

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,8 +1243,6 @@ static void ggml_metal_encode_node(
12431243

12441244
bool bcast_row = false;
12451245

1246-
int64_t nb = ne00; // used by the "row" kernels
1247-
12481246
id<MTLComputePipelineState> pipeline = nil;
12491247

12501248
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
@@ -1253,7 +1251,6 @@ static void ggml_metal_encode_node(
12531251
// src1 is a row
12541252
GGML_ASSERT(ne11 == 1);
12551253

1256-
nb = ne00 / 4;
12571254
switch (dst->op) {
12581255
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
12591256
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
@@ -1273,36 +1270,39 @@ static void ggml_metal_encode_node(
12731270
}
12741271
}
12751272

1273+
ggml_metal_kargs_bin args = {
1274+
/*.ne00 =*/ ne00,
1275+
/*.ne01 =*/ ne01,
1276+
/*.ne02 =*/ ne02,
1277+
/*.ne03 =*/ ne03,
1278+
/*.nb00 =*/ nb00,
1279+
/*.nb01 =*/ nb01,
1280+
/*.nb02 =*/ nb02,
1281+
/*.nb03 =*/ nb03,
1282+
/*.ne10 =*/ ne10,
1283+
/*.ne11 =*/ ne11,
1284+
/*.ne12 =*/ ne12,
1285+
/*.ne13 =*/ ne13,
1286+
/*.nb10 =*/ nb10,
1287+
/*.nb11 =*/ nb11,
1288+
/*.nb12 =*/ nb12,
1289+
/*.nb13 =*/ nb13,
1290+
/*.ne0 =*/ ne0,
1291+
/*.ne1 =*/ ne1,
1292+
/*.ne2 =*/ ne2,
1293+
/*.ne3 =*/ ne3,
1294+
/*.nb0 =*/ nb0,
1295+
/*.nb1 =*/ nb1,
1296+
/*.nb2 =*/ nb2,
1297+
/*.nb3 =*/ nb3,
1298+
/*.offs =*/ offs,
1299+
};
1300+
12761301
[encoder setComputePipelineState:pipeline];
1277-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1278-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1279-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1280-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1281-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1282-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1283-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1284-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1285-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1286-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1287-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1288-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1289-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1290-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1291-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1292-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1293-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1294-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1295-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1296-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1297-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1298-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1299-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1300-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1301-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1302-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1303-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1304-
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1305-
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
1302+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
1303+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1304+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
1305+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
13061306

13071307
if (bcast_row) {
13081308
const int64_t n = ggml_nelements(dst)/4;
@@ -1400,35 +1400,39 @@ static void ggml_metal_encode_node(
14001400

14011401
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
14021402

1403+
ggml_metal_kargs_bin args = {
1404+
/*.ne00 =*/ ne00,
1405+
/*.ne01 =*/ ne01,
1406+
/*.ne02 =*/ ne02,
1407+
/*.ne03 =*/ ne03,
1408+
/*.nb00 =*/ nb00,
1409+
/*.nb01 =*/ pnb1,
1410+
/*.nb02 =*/ pnb2,
1411+
/*.nb03 =*/ pnb3,
1412+
/*.ne10 =*/ ne10,
1413+
/*.ne11 =*/ ne11,
1414+
/*.ne12 =*/ ne12,
1415+
/*.ne13 =*/ ne13,
1416+
/*.nb10 =*/ nb10,
1417+
/*.nb11 =*/ nb11,
1418+
/*.nb12 =*/ nb12,
1419+
/*.nb13 =*/ nb13,
1420+
/*.ne0 =*/ ne0,
1421+
/*.ne1 =*/ ne1,
1422+
/*.ne2 =*/ ne2,
1423+
/*.ne3 =*/ ne3,
1424+
/*.nb0 =*/ nb0,
1425+
/*.nb1 =*/ pnb1,
1426+
/*.nb2 =*/ pnb2,
1427+
/*.nb3 =*/ pnb3,
1428+
/*.offs =*/ offs,
1429+
};
1430+
14031431
[encoder setComputePipelineState:pipeline];
1404-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1405-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1406-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1407-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1408-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1409-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1410-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1411-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1412-
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
1413-
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
1414-
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
1415-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1416-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1417-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1418-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1419-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1420-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1421-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1422-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1423-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1424-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1425-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1426-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1427-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1428-
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
1429-
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
1430-
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
1431-
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
1432+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
1433+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1434+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
1435+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
14321436

14331437
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
14341438

0 commit comments

Comments
 (0)