Commit 5feefb3

threading: add suspend/resume APIs, so it's possible to run a thread pool at session level
1 parent 5abb8ae

3 files changed: +204, -40 lines

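The commit message is the whole design in one line: because workers can now be parked and woken explicitly, a caller can create the thread pool once per session and reuse it across graph computations instead of starting and stopping threads around every call. A rough usage sketch under that reading follows; the ggml_threading_start() argument list and return type are inferred from ggml-threading.h below, and the values passed are illustrative only.

// Sketch of session-level pool reuse with the new suspend/resume API.
// Assumes ggml_threading_start() returns the context pointer and takes
// (n_threads, thread runner, task runner, features, stages_time); the
// NULL task runner relies on per-tensor runners being set elsewhere.
#include "ggml-threading.h"

static void session_sketch(void) {
    struct ggml_threading_context *ctx = ggml_threading_start(
        4, ggml_threading_graph_compute_thread, /*task_runner=*/NULL,
        GGML_THREADING_FEATURE_WAIT_ON_DONE, /*stages_time=*/NULL);

    // ... run one or more graph computations ...

    ggml_threading_suspend(ctx);   // park all workers while the session idles
    // ggml_threading_is_suspending(ctx) now returns true.

    ggml_threading_resume(ctx);    // wake the workers for the next request
    // ... more computations ...

    ggml_threading_stop(ctx);      // join workers and free the context
}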

ggml-threading.c

Lines changed: 54 additions & 27 deletions
@@ -194,6 +194,8 @@ struct ggml_threading_context {
     struct ggml_perf_stats wait_perf;
     struct ggml_perf_stats wakeup_perf;
 
+    atomic_bool suspending;
+
     int64_t *stages_time;
 };
 
@@ -252,15 +254,45 @@ static void ggml_threading_cond_wait(struct ggml_compute_state *state) {
     }
 }
 
+// Suspend
+void ggml_threading_suspend(struct ggml_threading_context *ctx) {
+    if (ctx->n_threads == 1) {
+        return;
+    }
+
+    struct ggml_compute_state_shared *shared = &ctx->shared;
+
+    ggml_spin_lock(&shared->spin);
+    ctx->shared.wait_now = true;
+    ggml_spin_unlock(&shared->spin);
+
+    const int n_worker_threads = ctx->n_threads - 1;
+
+    while (ctx->shared.n_waiting != n_worker_threads) {
+        ggml_spin_pause();
+    }
+
+    ggml_spin_lock(&shared->spin);
+    ctx->suspending = true;
+    ggml_spin_unlock(&shared->spin);
+    PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads);
+}
+
 // Wakeup all workers.
 //
 // Workers takes some time to wakeup, and has to lock spin after wakeup. Yield
 // is used to avoid signal frequently. Current implementation is highly
 // experimental. See tests/test-ggml-threading.c for details.
 //
 // NOTE: must be protected by shared->spin
-static void
-ggml_threading_wakeup_workers(struct ggml_compute_state_shared *shared) {
+void ggml_threading_resume(struct ggml_threading_context *ctx) {
+    if (ctx->n_threads == 1) {
+        return;
+    }
+
+    struct ggml_compute_state_shared *shared = &ctx->shared;
+    ggml_spin_lock(&shared->spin);
+
     int64_t perf_cycles_0 = 0;
     int64_t perf_time_0 = 0;
 
@@ -269,12 +301,11 @@ ggml_threading_wakeup_workers(struct ggml_compute_state_shared *shared) {
         perf_time_0 = ggml_time_us();
     }
 
-    shared->wait_now = false;
-
     int loop_counter = 0;
-    int notify_counter = 0;
     int64_t last_signal_time = 0;
 
+    shared->wait_now = false;
+
     while (shared->n_waiting != 0) {
         ggml_spin_unlock(&shared->spin);
 
@@ -294,22 +325,23 @@ ggml_threading_wakeup_workers(struct ggml_compute_state_shared *shared) {
         GGML_ASSERT(pthread_mutex_lock(&shared->mutex) == 0);
         GGML_ASSERT(pthread_cond_broadcast(&shared->cond) == 0);
         GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0);
-        ++notify_counter;
         last_signal_time = ggml_time_us();
 
         ggml_spin_lock(&shared->spin);
     }
 
+    ctx->suspending = false;
+
     if (shared->ctx->features & GGML_THREADING_FEATURE_PERF) {
         ggml_perf_collect(&shared->ctx->wakeup_perf, perf_cycles_0,
                           perf_time_0);
     }
 
-    // if (notify_counter > 1) {
-    //     printf("%s: loop counter: %d, notify counter: %d\n", __func__,
-    //            loop_counter, notify_counter);
-    // }
-    UNUSED(notify_counter);
+    ggml_spin_unlock(&shared->spin);
+}
+
+bool ggml_threading_is_suspending(struct ggml_threading_context *ctx) {
+    return ctx->suspending;
 }
 
 // Setup workers for a task stage.
@@ -329,7 +361,9 @@ static void ggml_threading_setup_workers(struct ggml_threading_context *ctx,
 
     if (current->parallel) {
         if (shared->n_waiting > 0) {
-            ggml_threading_wakeup_workers(shared);
+            ggml_spin_unlock(&shared->spin);
+            ggml_threading_resume(ctx);
+            ggml_spin_lock(&shared->spin);
         }
 
         if ((ctx->features & GGML_THREADING_FEATURE_WAIT_ON_DONE) > 0) {
@@ -351,17 +385,11 @@ static void ggml_threading_setup_workers(struct ggml_threading_context *ctx,
         }
     } else if (current->wait) {
         if (shared->n_waiting < n_worker_threads) {
-            shared->wait_now = true;
-            PRINT_DEBUG("[main] wait_now was set, expect %d workers wait\n",
+            PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n",
                         n_worker_threads);
-            ggml_spin_unlock(&shared->spin);
-
-            while (shared->n_waiting != n_worker_threads) {
-                ggml_spin_pause();
-            }
-
-            ggml_spin_lock(&shared->spin);
-            PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads);
+            ggml_spin_unlock(&ctx->shared.spin);
+            ggml_threading_suspend(ctx);
+            ggml_spin_lock(&ctx->shared.spin);
         }
     }
 
@@ -376,7 +404,7 @@ ggml_thread_ret_t ggml_threading_graph_compute_thread(void *data) {
 
     struct ggml_compute_state_shared *shared = state->shared;
     GGML_ASSERT(shared);
-    //GGML_ASSERT(shared->task_runner);
+    // GGML_ASSERT(shared->task_runner);
 
     shared->n_ready++;
 
@@ -527,7 +555,7 @@ ggml_threading_compute_tensor(struct ggml_threading_context *ctx,
         GGML_ASSERT(profiles[0].id == 1);
 
         memcpy(&node->task_profile, &profiles[0],
-                sizeof(struct ggml_task_profile));
+               sizeof(struct ggml_task_profile));
         runner = ctx->shared.task_runner;
         GGML_ASSERT(runner);
 
@@ -572,6 +600,7 @@ ggml_threading_start(int n_threads, ggml_threading_thread_runner *thread_runner,
     ctx->n_threads = n_threads;
     ctx->features = features;
     ctx->stages_time = stages_time;
+    ctx->suspending = false;
 
    int n_workers = n_threads - 1;
    if (n_workers > 0) {
@@ -633,9 +662,7 @@ void ggml_threading_stop(struct ggml_threading_context *ctx) {
     PRINT_DEBUG("[main] stopping thread pool ...\n");
     ctx->shared.stop = true;
 
-    ggml_spin_lock(&ctx->shared.spin);
-    ggml_threading_wakeup_workers(&ctx->shared);
-    ggml_spin_unlock(&ctx->shared.spin);
+    ggml_threading_resume(ctx);
 
     for (int j = 0; j < ctx->n_threads - 1; j++) {
         GGML_ASSERT(pthread_join(ctx->workers[j].thrd, NULL) == 0);
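The suspend/resume pair above is a small handshake: ggml_threading_suspend() raises wait_now under the spin lock and busy-waits until n_waiting shows that every worker has parked on the condition variable, while ggml_threading_resume() clears wait_now and keeps broadcasting the condition variable until n_waiting falls back to zero, because workers take some time to actually block and a single signal can be missed. The toy below sketches that flag + counter + condvar pattern in isolation; every name in it is invented for illustration, and it deliberately drops the spin lock, yield logic and perf counters of the real code.

// Toy model of the park/unpark handshake (not the ggml implementation).
#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>

typedef struct {
    pthread_mutex_t mutex;
    pthread_cond_t  cond;
    atomic_bool     wait_now;   // set by the main thread to ask workers to park
    atomic_int      n_waiting;  // how many workers are currently parked
} pool_t;

// Worker side: park when asked, stay parked until wait_now is cleared.
static void worker_maybe_park(pool_t *p) {
    if (!atomic_load(&p->wait_now)) return;
    pthread_mutex_lock(&p->mutex);
    atomic_fetch_add(&p->n_waiting, 1);
    while (atomic_load(&p->wait_now)) {
        pthread_cond_wait(&p->cond, &p->mutex);  // may wake spuriously; recheck flag
    }
    atomic_fetch_sub(&p->n_waiting, 1);
    pthread_mutex_unlock(&p->mutex);
}

// Main-thread side: loosely mirrors ggml_threading_suspend()/resume().
static void pool_suspend(pool_t *p, int n_workers) {
    atomic_store(&p->wait_now, true);
    while (atomic_load(&p->n_waiting) != n_workers) { /* spin until all parked */ }
}

static void pool_resume(pool_t *p) {
    atomic_store(&p->wait_now, false);
    while (atomic_load(&p->n_waiting) != 0) {
        pthread_mutex_lock(&p->mutex);
        pthread_cond_broadcast(&p->cond);  // repeat: workers may still be falling asleep
        pthread_mutex_unlock(&p->mutex);
    }
}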

ggml-threading.h

Lines changed: 11 additions & 1 deletion
@@ -25,11 +25,12 @@ enum ggml_threading_features {
 typedef ggml_thread_ret_t(ggml_threading_thread_runner)(void *data);
 
 // Init and start underlying workers if n_threads > 1.
+// n_threads: number of threads (including caller) involving in computing tasks.
 //
 // thread: optional OS thread runner, default value:
 // `ggml_threading_graph_compute_thread`.
 //
-// task_runner: default task runner, nullable wheen tensor.runner is not NULL.
+// task_runner: default task runner, nullable when tensor.runner is not NULL.
 // Overridden by tensor.runner.
 // features: configure threading behaviour, optional.
 // threading additional features. see `ggml_threading_feature`, default 0.
@@ -41,9 +42,18 @@ ggml_threading_start(int n_threads, ggml_threading_thread_runner *thread,
                      enum ggml_threading_features features,
                      int64_t stages_time[3]);
 
+// Suspend worker threads.
+void ggml_threading_suspend(struct ggml_threading_context *ctx);
+
+// Resume worker threads.
+void ggml_threading_resume(struct ggml_threading_context *ctx);
+
 // Stop workers (if exist), free memories (including the ctx).
 void ggml_threading_stop(struct ggml_threading_context *ctx);
 
+// Is all worker threads suspending?
+bool ggml_threading_is_suspending(struct ggml_threading_context *ctx);
+
 // The default implementation of `ggml_threading_thread_runner`
 ggml_thread_ret_t ggml_threading_graph_compute_thread(void *data);
 