Skip to content

Commit 83a00ce

Browse files
authored
metal : support bcast add & dup & cont op (#2323)
1 parent d2a4366 commit 83a00ce

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

ggml-metal.m

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
id<MTLComputePipelineState> pipeline_##name
4343

4444
GGML_METAL_DECL_KERNEL(add);
45+
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
4546
GGML_METAL_DECL_KERNEL(mul);
4647
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
4748
GGML_METAL_DECL_KERNEL(scale);
@@ -157,6 +158,7 @@ @implementation GGMLMetalClass
157158
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
158159

159160
GGML_METAL_ADD_KERNEL(add);
161+
GGML_METAL_ADD_KERNEL(add_row);
160162
GGML_METAL_ADD_KERNEL(mul);
161163
GGML_METAL_ADD_KERNEL(mul_row);
162164
GGML_METAL_ADD_KERNEL(scale);
@@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
464466
encoder = [command_buffer computeCommandEncoder];
465467
}
466468

467-
[encoder setComputePipelineState:ctx->pipeline_add];
469+
if (ggml_nelements(src1) == ne10) {
470+
// src1 is a row
471+
[encoder setComputePipelineState:ctx->pipeline_add_row];
472+
} else {
473+
[encoder setComputePipelineState:ctx->pipeline_add];
474+
}
468475
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
469476
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
470477
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
478+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
471479

472480
const int64_t n = ggml_nelements(dst);
473481

@@ -919,7 +927,9 @@ void ggml_metal_graph_compute(
919927

920928
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
921929
} break;
930+
case GGML_OP_DUP:
922931
case GGML_OP_CPY:
932+
case GGML_OP_CONT:
923933
{
924934
if (encoder == nil) {
925935
encoder = [command_buffer computeCommandEncoder];

ggml-metal.metal

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,17 @@ kernel void kernel_add(
6767
dst[tpig] = src0[tpig] + src1[tpig];
6868
}
6969

70+
// assumption: src1 is a single row; the kernel only ever reads src1 at indices (tpig % ne00)
71+
// broadcast src1 over src0: each thread writes one element, dst[i] = src0[i] + src1[i % ne00]
72+
kernel void kernel_add_row(
73+
device const float * src0,
74+
device const float * src1,
75+
device float * dst,
76+
constant int64_t & ne00, // period of the src1 index; presumably the row length of src0 -- TODO confirm against the host-side caller
77+
uint tpig[[thread_position_in_grid]]) {
78+
dst[tpig] = src0[tpig] + src1[tpig % ne00]; // wrap the flat index so the same src1 values are reused every ne00 elements
79+
}
80+
7081
kernel void kernel_mul(
7182
device const float * src0,
7283
device const float * src1,

0 commit comments

Comments
 (0)