@@ -15773,8 +15773,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.n_nodes      =*/ 0,
         /*.n_leafs      =*/ 0,
         /*.n_threads    =*/ GGML_DEFAULT_N_THREADS,
-        /*.work_size    =*/ 0,
-        /*.work         =*/ NULL,
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
@@ -15946,6 +15944,7 @@ void clear_numa_thread_affinity(void) {}
 
 struct ggml_compute_state_shared {
     struct ggml_cgraph * cgraph;
+    struct ggml_cgraph_context * cgraph_ctx;
 
     int64_t perf_node_start_cycles;
     int64_t perf_node_start_time_us;
@@ -15975,6 +15974,7 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
 static thread_ret_t ggml_graph_compute_thread(void * data) {
     struct ggml_compute_state * state = (struct ggml_compute_state *) data;
     struct ggml_cgraph * cgraph = state->shared->cgraph;
+    struct ggml_cgraph_context * ctx = state->shared->cgraph_ctx;
 
     const int n_threads = state->shared->n_threads;
     set_numa_thread_affinity(state->ith, n_threads);
@@ -15989,8 +15989,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             /*.type  =*/ GGML_TASK_FINALIZE,
             /*.ith   =*/ 0,
             /*.nth   =*/ 0,
-            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+            /*.wsize =*/ ctx->work_size,
+            /*.wdata =*/ ctx->work_data,
         };
 
         if (node_n != -1) {
@@ -16057,8 +16057,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             /*.type  =*/ GGML_TASK_COMPUTE,
             /*.ith   =*/ state->ith,
             /*.nth   =*/ node->n_tasks,
-            /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-            /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+            /*.wsize =*/ ctx->work_size,
+            /*.wdata =*/ ctx->work_data,
         };
 
         if (state->ith < node->n_tasks) {
@@ -16069,23 +16069,20 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     return 0;
 }
 
-void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    const int n_threads = cgraph->n_threads;
+// Prepare for graph computing.
+// Will set: node->n_tasks, ctx->{work_size, planned}
+void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(ctx);
+    // This function is actually reentrant, but duplicate calls are unnecessary.
+    GGML_ASSERT(ctx->work_size == 0);
+    GGML_ASSERT(ctx->work_data == NULL);
+    GGML_ASSERT(!ctx->planned);
 
-    struct ggml_compute_state_shared state_shared = {
-        /*.cgraph                  =*/ cgraph,
-        /*.perf_node_start_cycles  =*/ 0,
-        /*.perf_node_start_time_us =*/ 0,
-        /*.n_threads               =*/ n_threads,
-        /*.n_active                =*/ n_threads,
-        /*.node_n                  =*/ -1,
-    };
-    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+    int n_threads = cgraph->n_threads;
+    size_t work_size = 0;
 
     // initialize tasks + work buffer
     {
-        size_t work_size = 0;
-
         // thread scheduling for the different operations
         for (int i = 0; i < cgraph->n_nodes; i++) {
             struct ggml_tensor * node = cgraph->nodes[i];
@@ -16399,19 +16396,53 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 } break;
             }
         }
+    }
 
-        if (cgraph->work != NULL && work_size > cgraph->work_size) {
-            GGML_ASSERT(false); // TODO: better handling
-        }
+    if (work_size > 0) {
+        work_size += CACHE_LINE_SIZE*(n_threads - 1);
+    }
+
+    ctx->work_size = work_size;
+    ctx->work_data = NULL;
+    ctx->planned = true;
+}
 
-        if (work_size > 0 && cgraph->work == NULL) {
-            cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+void ggml_graph_compute_v2(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph) {
+    if (ctx == NULL) {
+        ctx = alloca(sizeof(struct ggml_cgraph_context));
+        GGML_ASSERT(ctx);
+        ctx->work_size = 0;
+        ctx->work_data = NULL;
+        ctx->planned = false;
+    } else {
+        // The work_size and work_data MAY have default values even if it has been planned.
+        if (ctx->work_size > 0) {
+            GGML_ASSERT(ctx->work_data);
+        }
+    }
 
-            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
-            cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+    if (!ctx->planned) {
+        ggml_graph_compute_plan(ctx, cgraph);
+        if (ctx->work_size > 0) {
+            ctx->work_data = malloc(ctx->work_size * sizeof(GGML_TYPE_I8));
+            GGML_ASSERT(ctx->work_data);
+            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, ctx->work_size);
         }
     }
 
+    const int n_threads = cgraph->n_threads;
+
+    struct ggml_compute_state_shared state_shared = {
+        /*.cgraph                  =*/ cgraph,
+        /*.cgraph_ctx              =*/ ctx,
+        /*.perf_node_start_cycles  =*/ 0,
+        /*.perf_node_start_time_us =*/ 0,
+        /*.n_threads               =*/ n_threads,
+        /*.n_active                =*/ n_threads,
+        /*.node_n                  =*/ -1,
+    };
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+
     // create thread pool
     if (n_threads > 1) {
         for (int j = 1; j < n_threads; ++j) {
@@ -16463,6 +16494,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
     }
 }
 
+// Deprecated, kept only for backward compatibility.
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+    UNUSED(ctx);
+    ggml_graph_compute_v2(NULL, cgraph);
+}
+
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * grad = cgraph->grads[i];
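
Taken together, these hunks split the old ggml_graph_compute into an explicit planning step (ggml_graph_compute_plan) and an execution step (ggml_graph_compute_v2), with the work buffer tracked in a struct ggml_cgraph_context instead of inside the cgraph. The sketch below shows how a caller might drive the two-step form. It is illustrative only: the definition of struct ggml_cgraph_context and the public declarations of the new functions are not part of this section, and the fields used here (work_size, work_data, planned) are inferred from how the functions above read and write them.

// Hypothetical caller of the new two-step API (assumes the new functions and
// struct ggml_cgraph_context are declared in ggml.h as part of this change).
#include <stdlib.h>
#include "ggml.h"

static void compute_with_caller_owned_work_buffer(struct ggml_cgraph * gf) {
    struct ggml_cgraph_context cctx;
    cctx.work_size = 0;
    cctx.work_data = NULL;
    cctx.planned   = false;

    // Plan only: sets node->n_tasks for every node and cctx.work_size.
    ggml_graph_compute_plan(&cctx, gf);

    // The caller owns the work buffer, so it can come from anywhere
    // (heap, arena, a reusable scratch region, ...).
    if (cctx.work_size > 0) {
        cctx.work_data = malloc(cctx.work_size);
        GGML_ASSERT(cctx.work_data);
    }

    // Execute with the pre-planned context.
    ggml_graph_compute_v2(&cctx, gf);

    free(cctx.work_data);
}

Passing NULL instead, as the deprecated ggml_graph_compute wrapper does, makes ggml_graph_compute_v2 plan and allocate the work buffer internally, so existing one-shot callers keep working; owning the context yourself is what lets you plan once and reuse the buffer across repeated graph evaluations.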