Skip to content

Commit e418ccf

Browse files
committed
metal : GGML_OP_CONCAT
ggml-ci
1 parent 5d4cbc0 commit e418ccf

File tree

3 files changed

+75
-67
lines changed

3 files changed

+75
-67
lines changed

ggml/src/ggml-common.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,34 @@ typedef struct {
419419
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
420420

421421
#if defined(GGML_COMMON_DECL_METAL_KARGS)
422+
typedef struct {
423+
int32_t ne00;
424+
int32_t ne01;
425+
int32_t ne02;
426+
int32_t ne03;
427+
uint64_t nb00;
428+
uint64_t nb01;
429+
uint64_t nb02;
430+
uint64_t nb03;
431+
int32_t ne10;
432+
int32_t ne11;
433+
int32_t ne12;
434+
int32_t ne13;
435+
uint64_t nb10;
436+
uint64_t nb11;
437+
uint64_t nb12;
438+
uint64_t nb13;
439+
int32_t ne0;
440+
int32_t ne1;
441+
int32_t ne2;
442+
int32_t ne3;
443+
uint64_t nb0;
444+
uint64_t nb1;
445+
uint64_t nb2;
446+
uint64_t nb3;
447+
int32_t dim;
448+
} ggml_metal_kargs_concat;
449+
422450
typedef struct {
423451
int32_t ne00;
424452
int32_t ne01;

ggml/src/ggml-metal.m

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,35 +1193,39 @@ static void ggml_metal_encode_node(
11931193

11941194
const int32_t dim = ((const int32_t *) dst->op_params)[0];
11951195

1196+
ggml_metal_kargs_concat args = {
1197+
/*.ne00 =*/ ne00,
1198+
/*.ne01 =*/ ne01,
1199+
/*.ne02 =*/ ne02,
1200+
/*.ne03 =*/ ne03,
1201+
/*.nb00 =*/ nb00,
1202+
/*.nb01 =*/ nb01,
1203+
/*.nb02 =*/ nb02,
1204+
/*.nb03 =*/ nb03,
1205+
/*.ne10 =*/ ne10,
1206+
/*.ne11 =*/ ne11,
1207+
/*.ne12 =*/ ne12,
1208+
/*.ne13 =*/ ne13,
1209+
/*.nb10 =*/ nb10,
1210+
/*.nb11 =*/ nb11,
1211+
/*.nb12 =*/ nb12,
1212+
/*.nb13 =*/ nb13,
1213+
/*.ne0 =*/ ne0,
1214+
/*.ne1 =*/ ne1,
1215+
/*.ne2 =*/ ne2,
1216+
/*.ne3 =*/ ne3,
1217+
/*.nb0 =*/ nb0,
1218+
/*.nb1 =*/ nb1,
1219+
/*.nb2 =*/ nb2,
1220+
/*.nb3 =*/ nb3,
1221+
/*.dim =*/ dim,
1222+
};
1223+
11961224
[encoder setComputePipelineState:pipeline];
1197-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1198-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1199-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1200-
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1201-
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1202-
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1203-
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1204-
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1205-
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1206-
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1207-
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1208-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1209-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1210-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1211-
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1212-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1213-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1214-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1215-
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1216-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1217-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1218-
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1219-
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1220-
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1221-
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1222-
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1223-
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1224-
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1225+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
1226+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1227+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
1228+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
12251229

12261230
const int nth = MIN(1024, ne0);
12271231

ggml/src/ggml-metal.metal

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,7 +1893,7 @@ void kernel_mul_mv_impl(
18931893

18941894
float sumf = 0;
18951895
for (int i = tiisg; i < args.ne00/4; i += 32) {
1896-
sumf += dot((T14) x4[i], y4[i]);
1896+
sumf += dot((float4) x4[i], (float4) y4[i]);
18971897
}
18981898

18991899
float all_sum = simd_sum(sumf);
@@ -3876,55 +3876,31 @@ kernel void kernel_cpy_f32_iq4_nl(
38763876
}
38773877

38783878
kernel void kernel_concat(
3879+
constant ggml_metal_kargs_concat & args,
38793880
device const char * src0,
38803881
device const char * src1,
38813882
device char * dst,
3882-
constant int64_t & ne00,
3883-
constant int64_t & ne01,
3884-
constant int64_t & ne02,
3885-
constant int64_t & ne03,
3886-
constant uint64_t & nb00,
3887-
constant uint64_t & nb01,
3888-
constant uint64_t & nb02,
3889-
constant uint64_t & nb03,
3890-
constant int64_t & ne10,
3891-
constant int64_t & ne11,
3892-
constant int64_t & ne12,
3893-
constant int64_t & ne13,
3894-
constant uint64_t & nb10,
3895-
constant uint64_t & nb11,
3896-
constant uint64_t & nb12,
3897-
constant uint64_t & nb13,
3898-
constant int64_t & ne0,
3899-
constant int64_t & ne1,
3900-
constant int64_t & ne2,
3901-
constant int64_t & ne3,
3902-
constant uint64_t & nb0,
3903-
constant uint64_t & nb1,
3904-
constant uint64_t & nb2,
3905-
constant uint64_t & nb3,
3906-
constant int32_t & dim,
3907-
uint3 tgpig[[threadgroup_position_in_grid]],
3908-
uint3 tpitg[[thread_position_in_threadgroup]],
3909-
uint3 ntg[[threads_per_threadgroup]]) {
3883+
uint3 tgpig[[threadgroup_position_in_grid]],
3884+
ushort3 tpitg[[thread_position_in_threadgroup]],
3885+
ushort3 ntg[[threads_per_threadgroup]]) {
39103886

3911-
const int64_t i3 = tgpig.z;
3912-
const int64_t i2 = tgpig.y;
3913-
const int64_t i1 = tgpig.x;
3887+
const int i3 = tgpig.z;
3888+
const int i2 = tgpig.y;
3889+
const int i1 = tgpig.x;
39143890

3915-
int64_t o[4] = {0, 0, 0, 0};
3916-
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
3891+
int o[4] = {0, 0, 0, 0};
3892+
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
39173893

39183894
device const float * x;
39193895

3920-
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
3921-
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
3922-
x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
3896+
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
3897+
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
3898+
x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
39233899
} else {
3924-
x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
3900+
x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
39253901
}
39263902

3927-
device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
3903+
device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
39283904

39293905
*y = *x;
39303906
}

0 commit comments

Comments
 (0)