@@ -41,6 +41,9 @@
 
     id<MTLFunction> function_mul_mat_q4_0;
     id<MTLComputePipelineState> pipeline_mul_mat_q4_0;
+
+    id<MTLFunction> function_rope;
+    id<MTLComputePipelineState> pipeline_rope;
 };
 
 // MSL code
@@ -148,6 +151,10 @@
     ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
     ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
     fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);
+
+    ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
+    ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
+    fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
 }
 
 // MTLBuffer approach
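Note: the pipelines are created with error:nil, so a failed kernel compile only shows up as a null pointer in the log lines above. A hypothetical variant (not part of this change) that surfaces the error:

    NSError * error = nil;

    ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:&error];
    if (ctx->pipeline_rope == nil) {
        fprintf(stderr, "%s: failed to create pipeline for kernel_rope: %s\n",
                __func__, [[error localizedDescription] UTF8String]);
    }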
@@ -250,6 +257,10 @@ int llama_mtl_eval(
         fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
         switch (gf->nodes[i]->op) {
+            case GGML_OP_RESHAPE:
+                {
+                    // noop: a reshape is just a view of src0's data, so there is nothing to encode
+                } break;
             case GGML_OP_ADD:
                 {
                     if (encoder == nil) {
@@ -453,6 +464,68 @@ int llama_mtl_eval(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_ROPE:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+                    const int64_t ne01 = gf->nodes[i]->src0->ne[1];
+                    const int64_t ne02 = gf->nodes[i]->src0->ne[2];
+                    const int64_t ne03 = gf->nodes[i]->src0->ne[3];
+
+                    const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
+                    const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
+                    const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
+                    const uint64_t nb03 = gf->nodes[i]->src0->nb[3];
+
+                    const int64_t ne0 = gf->nodes[i]->ne[0];
+                    const int64_t ne1 = gf->nodes[i]->ne[1];
+                    const int64_t ne2 = gf->nodes[i]->ne[2];
+                    const int64_t ne3 = gf->nodes[i]->ne[3];
+
+                    const uint64_t nb0 = gf->nodes[i]->nb[0];
+                    const uint64_t nb1 = gf->nodes[i]->nb[1];
+                    const uint64_t nb2 = gf->nodes[i]->nb[2];
+                    const uint64_t nb3 = gf->nodes[i]->nb[3];
+
+                    const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!!
+                    const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
+                    const int mode   = ((int32_t *) gf->nodes[i]->src1->data)[2];
+
+                    printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
+
+                    [encoder setComputePipelineState:ctx->pipeline_rope];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+                    [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+                    [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+                    [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+                    [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+                    [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+                    [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+                    [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+                    [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+                    [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; // one single-thread threadgroup per src0 row
+                } break;
             default:
                 fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                 GGML_ASSERT(false);
@@ -486,7 +559,7 @@ int llama_mtl_eval(
 
     {
         const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
-        fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
+        fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0); // GPUStartTime/GPUEndTime are in seconds
    }
 
     // TODO
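The 21 argument slots bound above fix the shape of the device-side entry point. For orientation, here is a hypothetical MSL sketch of such a kernel_rope (the commit's actual MSL is not shown in this diff). It assumes mode == 0 and n_dims == ne0, i.e. the standard LLaMA-style rotation of adjacent float pairs, with one thread per src0 row to match the dispatch:

#include <metal_stdlib>
using namespace metal;

// hypothetical kernel_rope sketch; argument indices mirror the host-side
// setBuffer/setBytes calls (0-20). assumes mode == 0 and n_dims == ne0,
// so mode is bound but unused here.
kernel void kernel_rope(
        device const char * src0   [[buffer( 0)]],
        device       char * dst    [[buffer( 1)]],
        constant  int64_t & ne00   [[buffer( 2)]],
        constant  int64_t & ne01   [[buffer( 3)]],
        constant  int64_t & ne02   [[buffer( 4)]],
        constant  int64_t & ne03   [[buffer( 5)]],
        constant uint64_t & nb00   [[buffer( 6)]],
        constant uint64_t & nb01   [[buffer( 7)]],
        constant uint64_t & nb02   [[buffer( 8)]],
        constant uint64_t & nb03   [[buffer( 9)]],
        constant  int64_t & ne0    [[buffer(10)]],
        constant  int64_t & ne1    [[buffer(11)]],
        constant  int64_t & ne2    [[buffer(12)]],
        constant  int64_t & ne3    [[buffer(13)]],
        constant uint64_t & nb0    [[buffer(14)]],
        constant uint64_t & nb1    [[buffer(15)]],
        constant uint64_t & nb2    [[buffer(16)]],
        constant uint64_t & nb3    [[buffer(17)]],
        constant      int & n_past [[buffer(18)]],
        constant      int & n_dims [[buffer(19)]],
        constant      int & mode   [[buffer(20)]],
        uint3 tgpig [[threadgroup_position_in_grid]]) {
    // the grid is (ne01, ne02, ne03), so each thread owns one row of src0
    const int64_t i3 = tgpig.z;
    const int64_t i2 = tgpig.y;
    const int64_t i1 = tgpig.x;

    // token position of this row (sketch ignores mode, assumes the LLaMA case)
    const int64_t p = n_past + i2;

    const float theta_scale = pow(10000.0f, -2.0f/n_dims);

    float theta = (float) p;

    // rotate each adjacent (x0, x1) pair by theta, scaling theta per pair
    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
        const float cos_theta = cos(theta);
        const float sin_theta = sin(theta);

        theta *= theta_scale;

        device const float * src      = (device const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
        device       float * dst_data = (device       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

        const float x0 = src[0];
        const float x1 = src[1];

        dst_data[0] = x0*cos_theta - x1*sin_theta;
        dst_data[1] = x0*sin_theta + x1*cos_theta;
    }
}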