|
42 | 42 | id<MTLComputePipelineState> pipeline_##name
|
43 | 43 |
|
44 | 44 | GGML_METAL_DECL_KERNEL(add);
|
| 45 | + GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast |
45 | 46 | GGML_METAL_DECL_KERNEL(mul);
|
46 | 47 | GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
47 | 48 | GGML_METAL_DECL_KERNEL(scale);
|
@@ -157,6 +158,7 @@ @implementation GGMLMetalClass
|
157 | 158 | fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
|
158 | 159 |
|
159 | 160 | GGML_METAL_ADD_KERNEL(add);
|
| 161 | + GGML_METAL_ADD_KERNEL(add_row); |
160 | 162 | GGML_METAL_ADD_KERNEL(mul);
|
161 | 163 | GGML_METAL_ADD_KERNEL(mul_row);
|
162 | 164 | GGML_METAL_ADD_KERNEL(scale);
|
@@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
|
464 | 466 | encoder = [command_buffer computeCommandEncoder];
|
465 | 467 | }
|
466 | 468 |
|
467 |
| - [encoder setComputePipelineState:ctx->pipeline_add]; |
| 469 | + if (ggml_nelements(src1) == ne10) { |
| 470 | + // src1 is a row |
| 471 | + [encoder setComputePipelineState:ctx->pipeline_add_row]; |
| 472 | + } else { |
| 473 | + [encoder setComputePipelineState:ctx->pipeline_add]; |
| 474 | + } |
468 | 475 | [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
469 | 476 | [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
470 | 477 | [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 478 | + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; |
471 | 479 |
|
472 | 480 | const int64_t n = ggml_nelements(dst);
|
473 | 481 |
|
@@ -919,7 +927,9 @@ void ggml_metal_graph_compute(
|
919 | 927 |
|
920 | 928 | [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
921 | 929 | } break;
|
| 930 | + case GGML_OP_DUP: |
922 | 931 | case GGML_OP_CPY:
|
| 932 | + case GGML_OP_CONT: |
923 | 933 | {
|
924 | 934 | if (encoder == nil) {
|
925 | 935 | encoder = [command_buffer computeCommandEncoder];
|
|
0 commit comments