Commit 64afc0b

mtl : add mul kernel + confirm working

1 parent 72256eb commit 64afc0b

File tree

examples/mtl/mtl.m
examples/mtl/mtl.metal
llama.cpp

3 files changed: +44 -4 lines changed

examples/mtl/mtl.m

Lines changed: 30 additions & 1 deletion
@@ -24,6 +24,9 @@
     id<MTLFunction>             function_add;
     id<MTLComputePipelineState> pipeline_add;
 
+    id<MTLFunction>             function_mul;
+    id<MTLComputePipelineState> pipeline_mul;
+
     id<MTLFunction>             function_relu;
     id<MTLComputePipelineState> pipeline_relu;
 

@@ -119,6 +122,10 @@
     ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];
     fprintf(stderr, "%s: loaded kernel_add: %p\n", __func__, (void *) ctx->pipeline_add);
 
+    ctx->function_mul = [ctx->library newFunctionWithName:@"kernel_mul"];
+    ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
+    fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
+
     ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
     ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
     fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
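Worth noting: the pipelines are created with error:nil, so a misspelled kernel name or a shader compile failure only shows up as a 0x0 pointer in the log. A minimal sketch of checked loading, assuming the same ctx fields (not part of the commit):

    NSError * error = nil;

    ctx->function_mul = [ctx->library newFunctionWithName:@"kernel_mul"];
    if (ctx->function_mul == nil) {
        // the library has no function with that name
        fprintf(stderr, "%s: kernel_mul not found in library\n", __func__);
    } else {
        ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:&error];
        if (ctx->pipeline_mul == nil) {
            // pipeline creation failed; report the Metal error instead of a bare pointer
            fprintf(stderr, "%s: failed to create pipeline for kernel_mul: %s\n",
                    __func__, error ? [[error localizedDescription] UTF8String] : "unknown error");
        }
    }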
@@ -253,6 +260,28 @@ int llama_mtl_eval(
 
                     const int64_t n = ggml_nelements(gf->nodes[i]);
 
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_MUL:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+
+                    [encoder setComputePipelineState:ctx->pipeline_mul];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+                    const int64_t n = ggml_nelements(gf->nodes[i]);
+
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_RELU:
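Both the add and mul cases dispatch n threadgroups of a single thread each, which is the simplest correct mapping of one thread per element but leaves most of each GPU core idle. A sketch of a wider dispatch (illustration only; it assumes the kernel is also given the element count and bounds-checks gid, which the committed kernel_mul does not do):

    // pick a threadgroup size up to what the pipeline supports
    const NSUInteger nth = MIN((NSUInteger) n, ctx->pipeline_mul.maxTotalThreadsPerThreadgroup);

    // round the grid up so every element is covered; out-of-range threads
    // must be filtered inside the kernel (if (gid >= n) return;)
    [encoder dispatchThreadgroups:MTLSizeMake((n + nth - 1) / nth, 1, 1)
            threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];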
@@ -373,7 +402,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                     [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                     [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                    [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
 
                     const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
 

(the eps line changes only in whitespace alignment)
examples/mtl/mtl.metal

Lines changed: 11 additions & 0 deletions
@@ -42,6 +42,17 @@ kernel void kernel_add(
     dst[gid] = src0[gid] + src1[gid];
 }
 
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint gid[[thread_position_in_grid]]) {
+    dst[gid] = src0[gid] * src1[gid % ne00];
+}
+
 kernel void kernel_relu(
         device const float * src0,
         device       float * dst,
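The gid % ne00 indexing is what implements the broadcast: src1 holds a single row of ne00 floats that is reused for every row of src0. A CPU reference with the same semantics, useful for checking the kernel (illustration only; like the kernel, it assumes src0 rows are stored contiguously, with no stride padding):

    #include <stdint.h>

    // dst = src0 * src1, where src1 is one row of ne00 floats broadcast
    // across all rows of src0; n is the total number of elements in src0
    static void ref_mul_broadcast_row(
            const float * src0, const float * src1, float * dst,
            int64_t n, int64_t ne00) {
        for (int64_t i = 0; i < n; i++) {
            dst[i] = src0[i] * src1[i % ne00];
        }
    }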

llama.cpp

Lines changed: 3 additions & 3 deletions
@@ -1263,13 +1263,13 @@ static bool llama_eval_internal(
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL);
+
+            // cur = cur*attention_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
             // TODO: TMP !!!!
             if (il == 0) {
                 ggml_set_name(cur, "mtl-check");
             }
-
-            // cur = cur*attention_norm(broadcasted)
-            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
         }
 
         // self-attention
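The reorder is what makes the "confirm working" in the commit message possible: the tensor named mtl-check is presumably the one exported and compared between the CPU and Metal paths, and moving the ggml_mul above the ggml_set_name call means the checked tensor now includes the new broadcast multiply rather than just the bare RMS norm. The comparison step would look something like this hypothetical helper (not from the commit):

    #include <math.h>
    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    // element-wise comparison of a GPU result against the CPU reference
    static bool all_close(const float * a, const float * b, int64_t n, float tol) {
        for (int64_t i = 0; i < n; i++) {
            if (fabsf(a[i] - b[i]) > tol) {
                fprintf(stderr, "mismatch at %lld: %f vs %f\n", (long long) i, a[i], b[i]);
                return false;
            }
        }
        return true;
    }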
