@@ -1377,25 +1377,29 @@ static void ggml_metal_encode_node(
1377
1377
1378
1378
const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline ;
1379
1379
1380
+ ggml_metal_kargs_cpy args = {
1381
+ /* .ne00 =*/ ne00,
1382
+ /* .ne01 =*/ ne01,
1383
+ /* .ne02 =*/ ne02,
1384
+ /* .ne03 =*/ ne03,
1385
+ /* .nb00 =*/ nb00,
1386
+ /* .nb01 =*/ nb01,
1387
+ /* .nb02 =*/ nb02,
1388
+ /* .nb03 =*/ nb03,
1389
+ /* .ne0 =*/ ne0,
1390
+ /* .ne1 =*/ ne1,
1391
+ /* .ne2 =*/ ne2,
1392
+ /* .ne3 =*/ ne3,
1393
+ /* .nb0 =*/ nb0,
1394
+ /* .nb1 =*/ nb1,
1395
+ /* .nb2 =*/ nb2,
1396
+ /* .nb3 =*/ nb3,
1397
+ };
1398
+
1380
1399
[encoder setComputePipelineState: pipeline];
1381
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1382
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1383
- [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1384
- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
1385
- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
1386
- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 5 ];
1387
- [encoder setBytes: &nb00 length: sizeof (uint64_t ) atIndex: 6 ];
1388
- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 7 ];
1389
- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 8 ];
1390
- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 9 ];
1391
- [encoder setBytes: &ne0 length: sizeof ( int64_t ) atIndex: 10 ];
1392
- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 11 ];
1393
- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 12 ];
1394
- [encoder setBytes: &ne3 length: sizeof ( int64_t ) atIndex: 13 ];
1395
- [encoder setBytes: &nb0 length: sizeof (uint64_t ) atIndex: 14 ];
1396
- [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 15 ];
1397
- [encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 16 ];
1398
- [encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 17 ];
1400
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
1401
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
1402
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1399
1403
1400
1404
const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne00);
1401
1405
@@ -3425,25 +3429,29 @@ static void ggml_metal_encode_node(
3425
3429
default : GGML_ABORT (" not implemented" );
3426
3430
}
3427
3431
3432
+ ggml_metal_kargs_cpy args = {
3433
+ /* .ne00 =*/ ne00,
3434
+ /* .ne01 =*/ ne01,
3435
+ /* .ne02 =*/ ne02,
3436
+ /* .ne03 =*/ ne03,
3437
+ /* .nb00 =*/ nb00,
3438
+ /* .nb01 =*/ nb01,
3439
+ /* .nb02 =*/ nb02,
3440
+ /* .nb03 =*/ nb03,
3441
+ /* .ne0 =*/ ne0,
3442
+ /* .ne1 =*/ ne1,
3443
+ /* .ne2 =*/ ne2,
3444
+ /* .ne3 =*/ ne3,
3445
+ /* .nb0 =*/ nb0,
3446
+ /* .nb1 =*/ nb1,
3447
+ /* .nb2 =*/ nb2,
3448
+ /* .nb3 =*/ nb3,
3449
+ };
3450
+
3428
3451
[encoder setComputePipelineState: pipeline];
3429
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3430
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
3431
- [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
3432
- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
3433
- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
3434
- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 5 ];
3435
- [encoder setBytes: &nb00 length: sizeof (uint64_t ) atIndex: 6 ];
3436
- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 7 ];
3437
- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 8 ];
3438
- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 9 ];
3439
- [encoder setBytes: &ne0 length: sizeof ( int64_t ) atIndex: 10 ];
3440
- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 11 ];
3441
- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 12 ];
3442
- [encoder setBytes: &ne3 length: sizeof ( int64_t ) atIndex: 13 ];
3443
- [encoder setBytes: &nb0 length: sizeof (uint64_t ) atIndex: 14 ];
3444
- [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 15 ];
3445
- [encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 16 ];
3446
- [encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 17 ];
3452
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3453
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3454
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3447
3455
3448
3456
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
3449
3457
} break ;
0 commit comments