Skip to content

Commit 3b03df5

Browse files
committed
look forward more
1 parent 921296c commit 3b03df5

File tree

1 file changed

+37
-43
lines changed

1 file changed

+37
-43
lines changed

ggml.c

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -9249,16 +9249,10 @@ typedef int ggml_lock_t;
92499249

92509250
#endif
92519251

9252-
struct ggml_compute_state_shared {
9253-
int n_threads;
9254-
};
9255-
92569252
struct ggml_compute_state {
92579253

92589254
struct ggml_compute_params params;
92599255
struct ggml_tensor * node;
9260-
9261-
struct ggml_compute_state_shared * shared;
92629256
};
92639257

92649258
static void ggml_graph_compute_thread(void * data) {
@@ -9284,9 +9278,6 @@ static void ggml_graph_compute_thread(void * data) {
92849278

92859279
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
92869280
const int n_threads = cgraph->n_threads;
9287-
struct ggml_compute_state_shared state_shared = {
9288-
/*.n_threads =*/ n_threads,
9289-
};
92909281
struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
92919282

92929283
// create thread pool
@@ -9302,7 +9293,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
93029293
.wdata = cgraph->work ? cgraph->work->data : NULL,
93039294
},
93049295
.node = NULL,
9305-
.shared = &state_shared,
93069296
};
93079297
}
93089298
}
@@ -9520,6 +9510,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
95209510

95219511
struct ggml_tensor * node = cgraph->nodes[i];
95229512

9513+
if (node->n_tasks == 0)
9514+
{
9515+
// no work needs to be done for this node (n_tasks == 0).
9516+
continue;
9517+
}
95239518
// TODO: this could be used to avoid unnecessary computations, but it needs to be improved
95249519
//if (node->grad == NULL && node->perf_runs > 0) {
95259520
// continue;
@@ -9558,46 +9553,45 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
95589553
}
95599554
else
95609555
{
9561-
if (i + 1 < cgraph->n_nodes)
9556+
int start = i;
9557+
int end = i + 1;
9558+
while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2)
95629559
{
9563-
struct ggml_tensor * next = cgraph->nodes[i + 1];
9564-
if (next->src0 != node && next->src1 != node && next->n_tasks == 1)
9560+
struct ggml_tensor * next = cgraph->nodes[end];
9561+
end++;
9562+
9563+
if (next->n_tasks != 1)
9564+
continue;
9565+
9566+
// check src dependency
9567+
bool is_dep = false;
9568+
for (int k = start; k < end; k++)
95659569
{
9566-
workers[next_task].params = (struct ggml_compute_params) {
9567-
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
9568-
.ith = 0,
9569-
.nth = 1,
9570-
.wsize = 0,
9571-
.wdata = NULL,
9572-
};
9573-
workers[next_task].node = next;
9574-
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
9575-
next_task++;
9576-
9577-
if (i + 2 < cgraph->n_nodes)
9570+
struct ggml_tensor * node = cgraph->nodes[k];
9571+
if (next->src0 == node || next->src1 == node)
95789572
{
9579-
struct ggml_tensor * prev = cgraph->nodes[i + 1];
9580-
struct ggml_tensor * next = cgraph->nodes[i + 2];
9581-
if (next->src0 != node && next->src1 != node && next->n_tasks == 1 &&
9582-
next->src0 != prev && next->src1 != prev
9583-
)
9584-
{
9585-
workers[next_task].params = (struct ggml_compute_params) {
9586-
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
9587-
.ith = 0,
9588-
.nth = 1,
9589-
.wsize = 0,
9590-
.wdata = NULL,
9591-
};
9592-
workers[next_task].node = next;
9593-
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
9594-
next_task++;
9595-
}
9573+
is_dep = true;
9574+
break;
95969575
}
95979576
}
9598-
}
95999577

9578+
if (is_dep)
9579+
continue;
96009580

9581+
workers[next_task].params = (struct ggml_compute_params) {
9582+
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
9583+
.ith = 0,
9584+
.nth = 1,
9585+
.wsize = 0,
9586+
.wdata = NULL,
9587+
};
9588+
workers[next_task].node = next;
9589+
9590+
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
9591+
next->n_tasks = 0; // mark this node as already dispatched so the main loop skips it
9592+
next_task++;
9593+
//printf("Combine task [%d, %d]\n", start, end);
9594+
}
96019595
}
96029596

96039597
params.type = GGML_TASK_COMPUTE;

0 commit comments

Comments
 (0)