@@ -16605,8 +16605,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
        /*.n_nodes   =*/ 0,
        /*.n_leafs   =*/ 0,
        /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
-       /*.work_size =*/ 0,
-       /*.work      =*/ NULL,
        /*.nodes     =*/ { NULL },
        /*.grads     =*/ { NULL },
        /*.leafs     =*/ { NULL },
@@ -16778,6 +16776,7 @@ void clear_numa_thread_affinity(void) {}

struct ggml_compute_state_shared {
    struct ggml_cgraph * cgraph;
+   struct ggml_cgraph_context * cgraph_ctx;

    int64_t perf_node_start_cycles;
    int64_t perf_node_start_time_us;
@@ -16807,6 +16806,7 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
static thread_ret_t ggml_graph_compute_thread(void * data) {
    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
    struct ggml_cgraph * cgraph = state->shared->cgraph;
+   struct ggml_cgraph_context * ctx = state->shared->cgraph_ctx;

    const int n_threads = state->shared->n_threads;
    set_numa_thread_affinity(state->ith, n_threads);
@@ -16821,8 +16821,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
            /*.type  =*/ GGML_TASK_FINALIZE,
            /*.ith   =*/ 0,
            /*.nth   =*/ 0,
-           /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-           /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+           /*.wsize =*/ ctx->work_size,
+           /*.wdata =*/ ctx->work_data,
        };

        if (node_n != -1) {
@@ -16889,8 +16889,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
            /*.type  =*/ GGML_TASK_COMPUTE,
            /*.ith   =*/ state->ith,
            /*.nth   =*/ node->n_tasks,
-           /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-           /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+           /*.wsize =*/ ctx->work_size,
+           /*.wdata =*/ ctx->work_data,
        };

        if (state->ith < node->n_tasks) {
@@ -16901,23 +16901,20 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
    return 0;
}

-void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    const int n_threads = cgraph->n_threads;
+// Prepare for graph computing.
+// Will set: node->n_tasks, ctx->{work_size, planned}
+void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(ctx);
+    // This function is actually reentrant, but duplicate calls are unnecessary.
+    GGML_ASSERT(ctx->work_size == 0);
+    GGML_ASSERT(ctx->work_data == NULL);
+    GGML_ASSERT(!ctx->planned);

-    struct ggml_compute_state_shared state_shared = {
-        /*.cgraph                  =*/ cgraph,
-        /*.perf_node_start_cycles  =*/ 0,
-        /*.perf_node_start_time_us =*/ 0,
-        /*.n_threads               =*/ n_threads,
-        /*.n_active                =*/ n_threads,
-        /*.node_n                  =*/ -1,
-    };
-    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+    int n_threads = cgraph->n_threads;
+    size_t work_size = 0;

    // initialize tasks + work buffer
    {
-        size_t work_size = 0;
-
        // thread scheduling for the different operations
        for (int i = 0; i < cgraph->n_nodes; i++) {
            struct ggml_tensor * node = cgraph->nodes[i];
@@ -17247,19 +17244,53 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                } break;
            }
        }
+    }

-        if (cgraph->work != NULL && work_size > cgraph->work_size) {
-            GGML_ASSERT(false); // TODO: better handling
-        }
+    if (work_size > 0) {
+        work_size += CACHE_LINE_SIZE*(n_threads - 1);
+    }
+
+    ctx->work_size = work_size;
+    ctx->work_data = NULL;
+    ctx->planned   = true;
+}

-        if (work_size > 0 && cgraph->work == NULL) {
-            cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+void ggml_graph_compute_v2(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph) {
+    if (ctx == NULL) {
+        ctx = alloca(sizeof(struct ggml_cgraph_context));
+        GGML_ASSERT(ctx);
+        ctx->work_size = 0;
+        ctx->work_data = NULL;
+        ctx->planned   = false;
+    } else {
+        // work_size and work_data MAY still have their default values even if the context has been planned.
+        if (ctx->work_size > 0) {
+            GGML_ASSERT(ctx->work_data);
+        }
+    }

-            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
-            cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+    if (!ctx->planned) {
+        ggml_graph_compute_plan(ctx, cgraph);
+        if (ctx->work_size > 0) {
+            ctx->work_data = malloc(ctx->work_size);
+            GGML_ASSERT(ctx->work_data);
+            GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, ctx->work_size);
        }
    }

+    const int n_threads = cgraph->n_threads;
+
+    struct ggml_compute_state_shared state_shared = {
+        /*.cgraph                  =*/ cgraph,
+        /*.cgraph_ctx              =*/ ctx,
+        /*.perf_node_start_cycles  =*/ 0,
+        /*.perf_node_start_time_us =*/ 0,
+        /*.n_threads               =*/ n_threads,
+        /*.n_active                =*/ n_threads,
+        /*.node_n                  =*/ -1,
+    };
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+
    // create thread pool
    if (n_threads > 1) {
        for (int j = 1; j < n_threads; ++j) {
@@ -17311,6 +17342,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
    }
}

+// Deprecated; kept only for backward compatibility.
+void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+    UNUSED(ctx);
+    ggml_graph_compute_v2(NULL, cgraph);
+}
+
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * grad = cgraph->grads[i];
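
For reference, a minimal caller sketch of the two-step API introduced above (plan, allocate, compute). The ggml_cgraph_context layout (work_size, work_data, planned) and the helper name compute_graph_with_plan are assumptions based only on the fields this diff touches; the struct definition sits outside the shown hunks, so field order and types here are illustrative, not authoritative.

    #include <stdbool.h>
    #include <stdlib.h>

    #include "ggml.h"

    // Hypothetical caller: plan once, allocate the work buffer, then compute.
    // Field order/types of ggml_cgraph_context are assumed from this diff.
    static void compute_graph_with_plan(struct ggml_cgraph * gf) {
        struct ggml_cgraph_context cctx = {
            /*.work_size =*/ 0,
            /*.work_data =*/ NULL,
            /*.planned   =*/ false,
        };

        // Sets node->n_tasks for every node, computes cctx.work_size, marks cctx.planned.
        ggml_graph_compute_plan(&cctx, gf);

        // When planning explicitly, the caller owns the work buffer.
        if (cctx.work_size > 0) {
            cctx.work_data = malloc(cctx.work_size);
            if (cctx.work_data == NULL) {
                return; // out of memory
            }
        }

        // Runs the graph with the pre-planned context; thread count comes from gf->n_threads.
        ggml_graph_compute_v2(&cctx, gf);

        free(cctx.work_data);
    }

Passing NULL instead of &cctx falls back to the internal alloca/malloc path shown in ggml_graph_compute_v2, which is what the deprecated ggml_graph_compute wrapper does.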