Commit 1213af7

mtl : add rope kernel
1 parent 6af6a05 commit 1213af7

3 files changed: +145 -9 lines


examples/mtl/mtl.m

Lines changed: 74 additions & 1 deletion
@@ -41,6 +41,9 @@
 
     id<MTLFunction>             function_mul_mat_q4_0;
     id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
+
+    id<MTLFunction>             function_rope;
+    id<MTLComputePipelineState> pipeline_rope;
 };
 
 // MSL code
@@ -148,6 +151,10 @@
         ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
         ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
         fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
+
+        ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
+        ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
+        fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
     }
 
     // MTLBuffer approach
@@ -250,6 +257,10 @@ int llama_mtl_eval(
         fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
         switch (gf->nodes[i]->op) {
+            case GGML_OP_RESHAPE:
+                {
+                    // noop
+                } break;
             case GGML_OP_ADD:
                 {
                     if (encoder == nil) {
@@ -453,6 +464,68 @@ int llama_mtl_eval(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_ROPE:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+                    const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+                    const int64_t ne03 = gf->nodes[i]->src0->ne[3];
+
+                    const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
+                    const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
+                    const uint64_t nb03 = gf->nodes[i]->src0->nb[3];
+
+                    const int64_t ne0 = gf->nodes[i]->ne[0];
+                    const int64_t ne1 = gf->nodes[i]->ne[1];
+                    const int64_t ne2 = gf->nodes[i]->ne[2];
+                    const int64_t ne3 = gf->nodes[i]->ne[3];
+
+                    const uint64_t nb0 = gf->nodes[i]->nb[0];
+                    const uint64_t nb1 = gf->nodes[i]->nb[1];
+                    const uint64_t nb2 = gf->nodes[i]->nb[2];
+                    const uint64_t nb3 = gf->nodes[i]->nb[3];
+
+                    const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!!
+                    const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
+                    const int mode   = ((int32_t *) gf->nodes[i]->src1->data)[2];
+
+                    printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
+
+                    [encoder setComputePipelineState:ctx->pipeline_rope];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+                    [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+                    [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             default:
                 fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                 GGML_ASSERT(false);
@@ -486,7 +559,7 @@ int llama_mtl_eval(
 
     {
         const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
-        fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
+        fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
     }
 
     // TODO
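
For reference, the GGML_OP_ROPE case above binds src0 and dst plus 19 scalar arguments at indices 0-20 in the same order as the kernel_rope parameter list in mtl.metal, then launches one single-thread threadgroup per source row. A minimal C sketch of the dispatch geometry this implies (the dimension and stride names come from the diff; the helper itself is illustrative, not from the commit):

    #include <stdint.h>

    // Illustrative only: a (ne01, ne02, ne03) grid with a 1x1x1 threadgroup means
    // kernel_rope runs once per (i1, i2, i3) triple; each invocation starts from
    // the row at byte offset i3*nb03 + i2*nb02 + i1*nb01 in src0.
    static void rope_grid_shape(int64_t ne01, int64_t ne02, int64_t ne03,
                                uint64_t nb01, uint64_t nb02, uint64_t nb03) {
        for (int64_t i3 = 0; i3 < ne03; ++i3) {
            for (int64_t i2 = 0; i2 < ne02; ++i2) {
                for (int64_t i1 = 0; i1 < ne01; ++i1) {
                    const uint64_t row_offs = i3*nb03 + i2*nb02 + i1*nb01;
                    (void) row_offs; // one kernel_rope invocation per row
                }
            }
        }
    }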

examples/mtl/mtl.metal

Lines changed: 55 additions & 0 deletions
@@ -210,3 +210,58 @@ kernel void kernel_mul_mat_q4_0(
         dst[r1*ne0 + r0] = sum[0];
     }
 }
+
+kernel void kernel_rope(
+        device const  void * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        constant       int & n_past,
+        constant       int & n_dims,
+        constant       int & mode,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i3 = tpig[2];
+    const int64_t i2 = tpig[1];
+    const int64_t i1 = tpig[0];
+
+    const bool is_neox = mode & 2;
+    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+
+    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
+
+    float theta = (float)p;
+
+    if (!is_neox) {
+        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+            const float cos_theta = cos(theta);
+            const float sin_theta = sin(theta);
+
+            theta *= theta_scale;
+
+            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        }
+    } else {
+        // TODO: implement
+    }
+}
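
The non-neox branch above is the standard rotary embedding: each adjacent pair (x0, x1) in a row at position p (n_past + i2 when bit 0 of mode is unset) is rotated by an angle theta that starts at p and is multiplied by 10000^(-2/n_dims) after every pair. A minimal CPU reference for that per-row rotation, useful only as an illustrative cross-check of the kernel output (assumes contiguous float rows; the helper name is hypothetical and not part of this commit):

    #include <math.h>
    #include <stdint.h>

    // Mirrors the non-neox path of kernel_rope for a single row at position i2.
    static void rope_row_ref(const float * src, float * dst,
                             int64_t ne0, int n_dims, int n_past, int64_t i2) {
        const float theta_scale = powf(10000.0f, -2.0f/n_dims);

        float theta = (float) (n_past + i2); // p when bit 0 of mode is not set

        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
            const float cos_theta = cosf(theta);
            const float sin_theta = sinf(theta);

            theta *= theta_scale;

            const float x0 = src[i0 + 0];
            const float x1 = src[i0 + 1];

            // rotate the (x0, x1) pair by theta
            dst[i0 + 0] = x0*cos_theta - x1*sin_theta;
            dst[i0 + 1] = x0*sin_theta + x1*cos_theta;
        }
    }

Comparing a few rows of this against the buffer produced by the Metal path is one way to sanity-check the kernel before the neox branch is implemented.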

llama.cpp

Lines changed: 16 additions & 8 deletions
@@ -1270,19 +1270,20 @@ static bool llama_eval_internal(
 
         // self-attention
        {
-            auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(x, "mtl-check");
-            }
+            //auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+            //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
 
             // compute Q and K and RoPE them
-            //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
             ggml_set_name(Qcur, "Qcur");
             ggml_set_name(Kcur, "Kcur");
 
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(Qcur, "mtl-check");
+            }
+
             // store key and value to memory
             {
                 // compute the transposed [N, n_embd] V matrix
@@ -1437,7 +1438,14 @@ static bool llama_eval_internal(
     //ggml_graph_compute (ctx0, &gf);
 
     // lets export a smaller graph to get things rolling -- baby steps first
-    ggml_build_forward_expand(&gf_export, ggml_get_tensor(ctx0, "mtl-check"));
+    {
+        struct ggml_tensor * t = ggml_get_tensor(ctx0, "mtl-check");
+        if (!t) {
+            fprintf(stderr, "%s: failed to find tensor 'mtl-check'\n", __func__);
+            exit(1);
+        }
+        ggml_build_forward_expand(&gf_export, t);
+    }
 
     // print
     {
