Skip to content

Commit 6f2a61e

Browse files
committed
Rework scheduling algorithm.
1 parent 2035a3c commit 6f2a61e

File tree

1 file changed

+56
-63
lines changed

1 file changed

+56
-63
lines changed

ggml.c

Lines changed: 56 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9279,23 +9279,22 @@ static void ggml_graph_compute_thread(void * data) {
92799279

92809280
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
92819281
const int n_threads = cgraph->n_threads;
9282-
struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
9282+
const int max_requests = n_threads * 5;
9283+
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*(max_requests));
92839284

92849285
// create thread pool
9285-
if (n_threads > 1) {
9286-
ctx->tpool = thpool_init(n_threads);
9287-
for (int j = 0; j < n_threads - 1; j++) {
9288-
workers[j] = (struct ggml_compute_state) {
9289-
.params = {
9290-
.type = GGML_TASK_COMPUTE,
9291-
.ith = j + 1,
9292-
.nth = n_threads,
9293-
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
9294-
.wdata = cgraph->work ? cgraph->work->data : NULL,
9295-
},
9296-
.node = NULL,
9297-
};
9298-
}
9286+
ctx->tpool = thpool_init(n_threads);
9287+
for (int j = 0; j < n_threads - 1; j++) {
9288+
workers[j] = (struct ggml_compute_state) {
9289+
.params = {
9290+
.type = GGML_TASK_COMPUTE,
9291+
.ith = j + 1,
9292+
.nth = n_threads,
9293+
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
9294+
.wdata = cgraph->work ? cgraph->work->data : NULL,
9295+
},
9296+
.node = NULL,
9297+
};
92999298
}
93009299

93019300
// initialize tasks + work buffer
@@ -9505,6 +9504,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
95059504

95069505
const int64_t perf_start_cycles = ggml_perf_cycles();
95079506
const int64_t perf_start_time_us = ggml_perf_time_us();
9507+
const size_t wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0;
9508+
const void* wdata = cgraph->work ? cgraph->work->data : NULL;
95089509

95099510
for (int i = 0; i < cgraph->n_nodes; i++) {
95109511
GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
@@ -9524,52 +9525,31 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
95249525
const int64_t perf_node_start_cycles = ggml_perf_cycles();
95259526
const int64_t perf_node_start_time_us = ggml_perf_time_us();
95269527

9527-
// INIT
9528-
struct ggml_compute_params params = {
9529-
.type = GGML_TASK_INIT,
9530-
.ith = 0,
9531-
.nth = node->n_tasks,
9532-
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
9533-
.wdata = cgraph->work ? cgraph->work->data : NULL,
9534-
};
9535-
9536-
ggml_compute_forward(&params, node);
9537-
95389528
int next_task = 0;
95399529

95409530
// COMPUTE
9541-
if (node->n_tasks > 1) {
9542-
// launch thread pool
9543-
for (int j = 0; j < n_threads - 1; j++) {
9544-
workers[j].params = (struct ggml_compute_params) {
9545-
.type = GGML_TASK_COMPUTE,
9546-
.ith = j + 1,
9547-
.nth = node->n_tasks,
9548-
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
9549-
.wdata = cgraph->work ? cgraph->work->data : NULL,
9550-
};
9551-
workers[j].node = node;
9552-
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]);
9553-
}
9554-
}
9555-
else
95569531
{
95579532
int start = i;
9558-
int end = i + 1;
9559-
while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2)
9533+
int end = i;
9534+
while (end < cgraph->n_nodes && (end - start) < n_threads * 2)
95609535
{
95619536
struct ggml_tensor * next = cgraph->nodes[end];
95629537
end++;
95639538

9564-
if (next->n_tasks != 1)
9539+
// already scheduled
9540+
if (next->n_tasks == 0)
9541+
continue;
9542+
9543+
// if we have slots
9544+
if (next_task + next->n_tasks > max_requests)
95659545
continue;
95669546

95679547
// check src dependency
95689548
bool is_dep = false;
95699549
for (int k = start; k < end; k++)
95709550
{
9571-
struct ggml_tensor * node = cgraph->nodes[k];
9572-
if (next->src0 == node || next->src1 == node)
9551+
struct ggml_tensor * prev = cgraph->nodes[k];
9552+
if (next->src0 == prev || next->src1 == prev)
95739553
{
95749554
is_dep = true;
95759555
break;
@@ -9579,29 +9559,42 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
95799559
if (is_dep)
95809560
continue;
95819561

9582-
workers[next_task].params = (struct ggml_compute_params) {
9583-
.type = GGML_TASK_INIT,
9584-
.ith = 0,
9585-
.nth = 1,
9586-
.wsize = 0,
9587-
.wdata = NULL,
9588-
};
9589-
workers[next_task].node = next;
9562+
if (next->n_tasks > 1)
9563+
{
9564+
// run INIT in the main thread if it is a multi-threaded operator
9565+
struct ggml_compute_params params = {
9566+
.type = GGML_TASK_INIT,
9567+
.ith = 0,
9568+
.nth = next->n_tasks,
9569+
.wsize = wsize,
9570+
.wdata = wdata,
9571+
};
9572+
9573+
ggml_compute_forward(&params, next);
9574+
}
95909575

9591-
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
9576+
for (int j = 0; j < next->n_tasks; j++) {
9577+
workers[next_task].params = (struct ggml_compute_params){
9578+
// single-threaded operators run INIT in the worker thread
9579+
.type = next->n_tasks == 1 ? GGML_TASK_INIT : GGML_TASK_COMPUTE,
9580+
.ith = j,
9581+
.nth = next->n_tasks,
9582+
9583+
// TODO: Potential race on wdata
9584+
.wsize = wsize,
9585+
.wdata = wdata,
9586+
};
9587+
workers[next_task].node = next;
9588+
9589+
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
9590+
next_task++;
9591+
}
95929592
next->n_tasks = 0; // indicate this node is calculated
9593-
next_task++;
9594-
//printf("Combine task [%d, %d]\n", start, end);
95959593
}
95969594
}
95979595

9598-
params.type = GGML_TASK_COMPUTE;
9599-
ggml_compute_forward(&params, node);
9600-
96019596
// wait for thread pool
9602-
if (node->n_tasks > 1 || next_task != 0) {
9603-
thpool_wait(ctx->tpool);
9604-
}
9597+
thpool_wait(ctx->tpool);
96059598
#if 0
96069599
// FINALIZE
96079600
if (node->n_tasks > 1) {

0 commit comments

Comments
 (0)