Skip to content

Commit 96b6912

Browse files
metal : single allocation of encode_async block (#9747)
* Single allocation of encode_async block with non-ARC capture in ggml-metal.m
* Moving Block_release to the deallocation code
* Release encode block when re-setting encoding buffer count if needed
* Update ggml/src/ggml-metal.m

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent d5cb868 commit 96b6912

File tree

1 file changed

+46
-46
lines changed

1 file changed

+46
-46
lines changed

ggml/src/ggml-metal.m

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,6 @@
239239
struct ggml_cgraph * gf;
240240

241241
// the callback given to the thread pool
242-
// TODO: ideally, this should be created once, utilizing the command buffer state above
243-
// for some reason, doing it like this leads to a crash
244242
void (^encode_async)(size_t ith);
245243

246244
// n_cb command buffers + 1 used by the main thread
@@ -683,6 +681,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
683681
[ctx->kernels[i].pipeline release];
684682
}
685683

684+
Block_release(ctx->encode_async);
685+
686686
[ctx->queue release];
687687
[ctx->device release];
688688

@@ -3000,46 +3000,6 @@ static enum ggml_status ggml_metal_graph_compute(
30003000
}
30013001
}
30023002

3003-
// TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
3004-
ctx->encode_async = ^(size_t iter) {
3005-
const int cb_idx = iter;
3006-
const int n_cb_l = ctx->n_cb;
3007-
3008-
const int n_nodes_0 = ctx->n_nodes_0;
3009-
const int n_nodes_1 = ctx->n_nodes_1;
3010-
3011-
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
3012-
3013-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
3014-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
3015-
3016-
int node_start = 0;
3017-
int node_end = n_nodes_0;
3018-
3019-
if (cb_idx < n_cb_l) {
3020-
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
3021-
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
3022-
}
3023-
3024-
for (int idx = node_start; idx < node_end; ++idx) {
3025-
if (should_capture) {
3026-
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
3027-
}
3028-
3029-
ggml_metal_encode_node(ctx, idx, encoder);
3030-
3031-
if (should_capture) {
3032-
[encoder popDebugGroup];
3033-
}
3034-
}
3035-
3036-
[encoder endEncoding];
3037-
3038-
if (cb_idx < 2 || ctx->abort_callback == NULL) {
3039-
[command_buffer commit];
3040-
}
3041-
};
3042-
30433003
// the main thread commits the first few commands immediately
30443004
// command_buffer[n_cb]
30453005
{
@@ -3468,10 +3428,50 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
34683428
}
34693429
}
34703430

3471-
// TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
3472-
//ctx->encode_async = ^(size_t iter) {
3473-
// ...
3474-
//};
3431+
if (ctx->encode_async) {
3432+
Block_release(ctx->encode_async);
3433+
}
3434+
3435+
ctx->encode_async = Block_copy(^(size_t iter) {
3436+
const int cb_idx = iter;
3437+
const int n_cb_l = ctx->n_cb;
3438+
3439+
const int n_nodes_0 = ctx->n_nodes_0;
3440+
const int n_nodes_1 = ctx->n_nodes_1;
3441+
3442+
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
3443+
3444+
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
3445+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
3446+
3447+
int node_start = 0;
3448+
int node_end = n_nodes_0;
3449+
3450+
if (cb_idx < n_cb_l) {
3451+
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
3452+
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
3453+
}
3454+
3455+
const bool should_capture = ctx->capture_next_compute;
3456+
3457+
for (int idx = node_start; idx < node_end; ++idx) {
3458+
if (should_capture) {
3459+
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
3460+
}
3461+
3462+
ggml_metal_encode_node(ctx, idx, encoder);
3463+
3464+
if (should_capture) {
3465+
[encoder popDebugGroup];
3466+
}
3467+
}
3468+
3469+
[encoder endEncoding];
3470+
3471+
if (cb_idx < 2 || ctx->abort_callback == NULL) {
3472+
[command_buffer commit];
3473+
}
3474+
});
34753475
}
34763476

34773477
static struct ggml_backend_i ggml_backend_metal_i = {

0 commit comments

Comments
 (0)