24 | 24 | id<MTLFunction> function_add;
25 | 25 | id<MTLComputePipelineState> pipeline_add;
26 | 26 |
   | 27 | + id<MTLFunction> function_mul;
   | 28 | + id<MTLComputePipelineState> pipeline_mul;
   | 29 | +
27 | 30 | id<MTLFunction> function_relu;
28 | 31 | id<MTLComputePipelineState> pipeline_relu;
29 | 32 |

119 | 122 | ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];
120 | 123 | fprintf(stderr, "%s: loaded kernel_add: %p\n", __func__, (void *) ctx->pipeline_add);
121 | 124 |
    | 125 | + ctx->function_mul = [ctx->library newFunctionWithName:@"kernel_mul"];
    | 126 | + ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
    | 127 | + fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
    | 128 | +
122 | 129 | ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
123 | 130 | ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
124 | 131 | fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
@@ -253,6 +260,28 @@ int llama_mtl_eval(
253 | 260 |
254 | 261 | const int64_t n = ggml_nelements(gf->nodes[i]);
255 | 262 |
    | 263 | + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
    | 264 | + } break;
    | 265 | + case GGML_OP_MUL:
    | 266 | + {
    | 267 | + if (encoder == nil) {
    | 268 | + encoder = [command_buffer computeCommandEncoder];
    | 269 | + }
    | 270 | +
    | 271 | + id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
    | 272 | + id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
    | 273 | + id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);
    | 274 | +
    | 275 | + const int64_t ne00 = gf->nodes[i]->src0->ne[0];
    | 276 | +
    | 277 | + [encoder setComputePipelineState:ctx->pipeline_mul];
    | 278 | + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
    | 279 | + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
    | 280 | + [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
    | 281 | + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
    | 282 | +
    | 283 | + const int64_t n = ggml_nelements(gf->nodes[i]);
    | 284 | +
256 | 285 | [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
257 | 286 | } break;
258 | 287 | case GGML_OP_RELU:
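
The Metal shader for GGML_OP_MUL is not part of this hunk. A sketch of a kernel_mul consistent with the host bindings above (src0 at index 0, src1 at index 1, dst at index 2, ne00 at index 3) and the one-thread-per-element dispatch, assuming src1 is a single row of ne00 elements broadcast across src0, which would explain why ne00 is the only size passed:

    #include <metal_stdlib>
    using namespace metal;

    // Sketch only: the actual shader lives in the .metal source, not shown in this diff.
    // One thread per destination element; src1 is assumed to be broadcast row-wise.
    kernel void kernel_mul(
            device const float * src0,
            device const float * src1,
            device       float * dst,
            constant   int64_t & ne00,
            uint tpig[[thread_position_in_grid]]) {
        dst[tpig] = src0[tpig] * src1[tpig % ne00];
    }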
@@ -373,7 +402,7 @@ int llama_mtl_eval(
373 | 402 | [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
374 | 403 | [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
375 | 404 | [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
376 |     | - [encoder setBytes:&eps length:sizeof( float) atIndex:4];
    | 405 | + [encoder setBytes:&eps  length:sizeof( float) atIndex:4];
377 | 406 |
378 | 407 | const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
379 | 408 |
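
This last hunk only realigns the eps argument; behavior is unchanged. For context, the binding set here (ne00 at index 2, nb01 at index 3, eps at index 4, with nrows computed below, presumably to dispatch one threadgroup per row) has the shape of a row-wise normalization kernel. A hypothetical shader-side sketch, since the kernel itself is not in this diff:

    #include <metal_stdlib>
    using namespace metal;

    // Hypothetical RMS-norm-style kernel matching the host bindings above:
    // src0 at 0 (bound outside this hunk), dst at 1, ne00 = row length,
    // nb01 = row stride in bytes, eps = stabilizer; one thread(group) per row.
    kernel void kernel_rms_norm(
            device const void  * src0,
            device       float * dst,
            constant   int64_t & ne00,
            constant  uint64_t & nb01,
            constant     float & eps,
            uint tgpig[[threadgroup_position_in_grid]]) {
        // locate this threadgroup's row via the byte stride nb01
        device const float * x = (device const float *)((device const char *) src0 + tgpig*nb01);

        float sum = 0.0f;
        for (int i = 0; i < ne00; i++) {
            sum += x[i]*x[i];
        }

        // scale the row by 1/sqrt(mean(x^2) + eps); dst is assumed contiguous
        const float scale = 1.0f/sqrt(sum/(float)ne00 + eps);
        for (int i = 0; i < ne00; i++) {
            dst[tgpig*ne00 + i] = x[i]*scale;
        }
    }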
|