|
239 | 239 | struct ggml_cgraph * gf;
|
240 | 240 |
|
241 | 241 | // the callback given to the thread pool
|
242 |
| - // TODO: ideally, this should be created once, utilizing the command buffer state above |
243 |
| - // for some reason, doing it like this leads to a crash |
244 | 242 | void (^encode_async)(size_t ith);
|
245 | 243 |
|
246 | 244 | // n_cb command buffers + 1 used by the main thread
|
@@ -683,6 +681,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
683 | 681 | [ctx->kernels[i].pipeline release];
|
684 | 682 | }
|
685 | 683 |
|
| 684 | + Block_release(ctx->encode_async); |
| 685 | + |
686 | 686 | [ctx->queue release];
|
687 | 687 | [ctx->device release];
|
688 | 688 |
|
@@ -3000,46 +3000,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
3000 | 3000 | }
|
3001 | 3001 | }
|
3002 | 3002 |
|
3003 |
| - // TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes. |
3004 |
| - ctx->encode_async = ^(size_t iter) { |
3005 |
| - const int cb_idx = iter; |
3006 |
| - const int n_cb_l = ctx->n_cb; |
3007 |
| - |
3008 |
| - const int n_nodes_0 = ctx->n_nodes_0; |
3009 |
| - const int n_nodes_1 = ctx->n_nodes_1; |
3010 |
| - |
3011 |
| - const int n_nodes_per_cb = ctx->n_nodes_per_cb; |
3012 |
| - |
3013 |
| - id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx]; |
3014 |
| - id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder]; |
3015 |
| - |
3016 |
| - int node_start = 0; |
3017 |
| - int node_end = n_nodes_0; |
3018 |
| - |
3019 |
| - if (cb_idx < n_cb_l) { |
3020 |
| - node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); |
3021 |
| - node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); |
3022 |
| - } |
3023 |
| - |
3024 |
| - for (int idx = node_start; idx < node_end; ++idx) { |
3025 |
| - if (should_capture) { |
3026 |
| - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]]; |
3027 |
| - } |
3028 |
| - |
3029 |
| - ggml_metal_encode_node(ctx, idx, encoder); |
3030 |
| - |
3031 |
| - if (should_capture) { |
3032 |
| - [encoder popDebugGroup]; |
3033 |
| - } |
3034 |
| - } |
3035 |
| - |
3036 |
| - [encoder endEncoding]; |
3037 |
| - |
3038 |
| - if (cb_idx < 2 || ctx->abort_callback == NULL) { |
3039 |
| - [command_buffer commit]; |
3040 |
| - } |
3041 |
| - }; |
3042 |
| - |
3043 | 3003 | // the main thread commits the first few commands immediately
|
3044 | 3004 | // command_buffer[n_cb]
|
3045 | 3005 | {
|
@@ -3468,10 +3428,50 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
3468 | 3428 | }
|
3469 | 3429 | }
|
3470 | 3430 |
|
3471 |
| - // TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why? |
3472 |
| - //ctx->encode_async = ^(size_t iter) { |
3473 |
| - // ... |
3474 |
| - //}; |
| 3431 | + if (ctx->encode_async) { |
| 3432 | + Block_release(ctx->encode_async); |
| 3433 | + } |
| 3434 | + |
| 3435 | + ctx->encode_async = Block_copy(^(size_t iter) { |
| 3436 | + const int cb_idx = iter; |
| 3437 | + const int n_cb_l = ctx->n_cb; |
| 3438 | + |
| 3439 | + const int n_nodes_0 = ctx->n_nodes_0; |
| 3440 | + const int n_nodes_1 = ctx->n_nodes_1; |
| 3441 | + |
| 3442 | + const int n_nodes_per_cb = ctx->n_nodes_per_cb; |
| 3443 | + |
| 3444 | + id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx]; |
| 3445 | + id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder]; |
| 3446 | + |
| 3447 | + int node_start = 0; |
| 3448 | + int node_end = n_nodes_0; |
| 3449 | + |
| 3450 | + if (cb_idx < n_cb_l) { |
| 3451 | + node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); |
| 3452 | + node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); |
| 3453 | + } |
| 3454 | + |
| 3455 | + const bool should_capture = ctx->capture_next_compute; |
| 3456 | + |
| 3457 | + for (int idx = node_start; idx < node_end; ++idx) { |
| 3458 | + if (should_capture) { |
| 3459 | + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; |
| 3460 | + } |
| 3461 | + |
| 3462 | + ggml_metal_encode_node(ctx, idx, encoder); |
| 3463 | + |
| 3464 | + if (should_capture) { |
| 3465 | + [encoder popDebugGroup]; |
| 3466 | + } |
| 3467 | + } |
| 3468 | + |
| 3469 | + [encoder endEncoding]; |
| 3470 | + |
| 3471 | + if (cb_idx < 2 || ctx->abort_callback == NULL) { |
| 3472 | + [command_buffer commit]; |
| 3473 | + } |
| 3474 | + }); |
3475 | 3475 | }
|
3476 | 3476 |
|
3477 | 3477 | static struct ggml_backend_i ggml_backend_metal_i = {
|
|
0 commit comments