@@ -1243,8 +1243,6 @@ static void ggml_metal_encode_node(
1243
1243
1244
1244
bool bcast_row = false ;
1245
1245
1246
- int64_t nb = ne00; // used by the "row" kernels
1247
-
1248
1246
id <MTLComputePipelineState > pipeline = nil ;
1249
1247
1250
1248
if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
@@ -1253,7 +1251,6 @@ static void ggml_metal_encode_node(
1253
1251
// src1 is a row
1254
1252
GGML_ASSERT (ne11 == 1 );
1255
1253
1256
- nb = ne00 / 4 ;
1257
1254
switch (dst->op ) {
1258
1255
case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
1259
1256
case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
@@ -1273,36 +1270,39 @@ static void ggml_metal_encode_node(
1273
1270
}
1274
1271
}
1275
1272
1273
+ ggml_metal_kargs_bin args = {
1274
+ /* .ne00 =*/ ne00,
1275
+ /* .ne01 =*/ ne01,
1276
+ /* .ne02 =*/ ne02,
1277
+ /* .ne03 =*/ ne03,
1278
+ /* .nb00 =*/ nb00,
1279
+ /* .nb01 =*/ nb01,
1280
+ /* .nb02 =*/ nb02,
1281
+ /* .nb03 =*/ nb03,
1282
+ /* .ne10 =*/ ne10,
1283
+ /* .ne11 =*/ ne11,
1284
+ /* .ne12 =*/ ne12,
1285
+ /* .ne13 =*/ ne13,
1286
+ /* .nb10 =*/ nb10,
1287
+ /* .nb11 =*/ nb11,
1288
+ /* .nb12 =*/ nb12,
1289
+ /* .nb13 =*/ nb13,
1290
+ /* .ne0 =*/ ne0,
1291
+ /* .ne1 =*/ ne1,
1292
+ /* .ne2 =*/ ne2,
1293
+ /* .ne3 =*/ ne3,
1294
+ /* .nb0 =*/ nb0,
1295
+ /* .nb1 =*/ nb1,
1296
+ /* .nb2 =*/ nb2,
1297
+ /* .nb3 =*/ nb3,
1298
+ /* .offs =*/ offs,
1299
+ };
1300
+
1276
1301
[encoder setComputePipelineState: pipeline];
1277
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1278
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1279
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1280
- [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
1281
- [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
1282
- [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
1283
- [encoder setBytes: &ne03 length: sizeof (ne03) atIndex: 6 ];
1284
- [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 7 ];
1285
- [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 8 ];
1286
- [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 9 ];
1287
- [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 10 ];
1288
- [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 11 ];
1289
- [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 12 ];
1290
- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 13 ];
1291
- [encoder setBytes: &ne13 length: sizeof (ne13) atIndex: 14 ];
1292
- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 15 ];
1293
- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 16 ];
1294
- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 17 ];
1295
- [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 18 ];
1296
- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 19 ];
1297
- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 20 ];
1298
- [encoder setBytes: &ne2 length: sizeof (ne2) atIndex: 21 ];
1299
- [encoder setBytes: &ne3 length: sizeof (ne3) atIndex: 22 ];
1300
- [encoder setBytes: &nb0 length: sizeof (nb0) atIndex: 23 ];
1301
- [encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 24 ];
1302
- [encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 25 ];
1303
- [encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 26 ];
1304
- [encoder setBytes: &offs length: sizeof (offs) atIndex: 27 ];
1305
- [encoder setBytes: &nb length: sizeof (nb) atIndex: 28 ];
1302
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
1303
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
1304
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
1305
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
1306
1306
1307
1307
if (bcast_row) {
1308
1308
const int64_t n = ggml_nelements (dst)/4 ;
@@ -1400,35 +1400,39 @@ static void ggml_metal_encode_node(
1400
1400
1401
1401
const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ;
1402
1402
1403
+ ggml_metal_kargs_bin args = {
1404
+ /* .ne00 =*/ ne00,
1405
+ /* .ne01 =*/ ne01,
1406
+ /* .ne02 =*/ ne02,
1407
+ /* .ne03 =*/ ne03,
1408
+ /* .nb00 =*/ nb00,
1409
+ /* .nb01 =*/ pnb1,
1410
+ /* .nb02 =*/ pnb2,
1411
+ /* .nb03 =*/ pnb3,
1412
+ /* .ne10 =*/ ne10,
1413
+ /* .ne11 =*/ ne11,
1414
+ /* .ne12 =*/ ne12,
1415
+ /* .ne13 =*/ ne13,
1416
+ /* .nb10 =*/ nb10,
1417
+ /* .nb11 =*/ nb11,
1418
+ /* .nb12 =*/ nb12,
1419
+ /* .nb13 =*/ nb13,
1420
+ /* .ne0 =*/ ne0,
1421
+ /* .ne1 =*/ ne1,
1422
+ /* .ne2 =*/ ne2,
1423
+ /* .ne3 =*/ ne3,
1424
+ /* .nb0 =*/ nb0,
1425
+ /* .nb1 =*/ pnb1,
1426
+ /* .nb2 =*/ pnb2,
1427
+ /* .nb3 =*/ pnb3,
1428
+ /* .offs =*/ offs,
1429
+ };
1430
+
1403
1431
[encoder setComputePipelineState: pipeline];
1404
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1405
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1406
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1407
- [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
1408
- [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
1409
- [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
1410
- [encoder setBytes: &ne03 length: sizeof (ne03) atIndex: 6 ];
1411
- [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 7 ];
1412
- [encoder setBytes: &pnb1 length: sizeof (pnb1) atIndex: 8 ];
1413
- [encoder setBytes: &pnb2 length: sizeof (pnb2) atIndex: 9 ];
1414
- [encoder setBytes: &pnb3 length: sizeof (pnb3) atIndex: 10 ];
1415
- [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 11 ];
1416
- [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 12 ];
1417
- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 13 ];
1418
- [encoder setBytes: &ne13 length: sizeof (ne13) atIndex: 14 ];
1419
- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 15 ];
1420
- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 16 ];
1421
- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 17 ];
1422
- [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 18 ];
1423
- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 19 ];
1424
- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 20 ];
1425
- [encoder setBytes: &ne2 length: sizeof (ne2) atIndex: 21 ];
1426
- [encoder setBytes: &ne3 length: sizeof (ne3) atIndex: 22 ];
1427
- [encoder setBytes: &nb0 length: sizeof (nb0) atIndex: 23 ];
1428
- [encoder setBytes: &pnb1 length: sizeof (pnb1) atIndex: 24 ];
1429
- [encoder setBytes: &pnb2 length: sizeof (pnb2) atIndex: 25 ];
1430
- [encoder setBytes: &pnb3 length: sizeof (pnb3) atIndex: 26 ];
1431
- [encoder setBytes: &offs length: sizeof (offs) atIndex: 27 ];
1432
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
1433
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
1434
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
1435
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
1432
1436
1433
1437
const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne00);
1434
1438
0 commit comments