Skip to content

Commit 286c5b3

Browse files
committed
threading: remove unnecessary spin lock/unlock from suspend/resume; add more tests
1 parent 5feefb3 commit 286c5b3

File tree

2 files changed

+19
-28
lines changed

2 files changed

+19
-28
lines changed

ggml-threading.c

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -260,22 +260,17 @@ void ggml_threading_suspend(struct ggml_threading_context *ctx) {
260260
return;
261261
}
262262

263-
struct ggml_compute_state_shared *shared = &ctx->shared;
264-
265-
ggml_spin_lock(&shared->spin);
263+
PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n",
264+
n_worker_threads);
266265
ctx->shared.wait_now = true;
267-
ggml_spin_unlock(&shared->spin);
268266

269267
const int n_worker_threads = ctx->n_threads - 1;
270-
271268
while (ctx->shared.n_waiting != n_worker_threads) {
272269
ggml_spin_pause();
273270
}
274271

275-
ggml_spin_lock(&shared->spin);
276-
ctx->suspending = true;
277-
ggml_spin_unlock(&shared->spin);
278272
PRINT_DEBUG("[main] saw %d workers waiting\n", n_worker_threads);
273+
ctx->suspending = true;
279274
}
280275

281276
// Wakeup all workers.
@@ -291,7 +286,6 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) {
291286
}
292287

293288
struct ggml_compute_state_shared *shared = &ctx->shared;
294-
ggml_spin_lock(&shared->spin);
295289

296290
int64_t perf_cycles_0 = 0;
297291
int64_t perf_time_0 = 0;
@@ -307,8 +301,6 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) {
307301
shared->wait_now = false;
308302

309303
while (shared->n_waiting != 0) {
310-
ggml_spin_unlock(&shared->spin);
311-
312304
if (loop_counter > 0) {
313305
ggml_spin_pause();
314306
if (loop_counter > 3) {
@@ -326,18 +318,14 @@ void ggml_threading_resume(struct ggml_threading_context *ctx) {
326318
GGML_ASSERT(pthread_cond_broadcast(&shared->cond) == 0);
327319
GGML_ASSERT(pthread_mutex_unlock(&shared->mutex) == 0);
328320
last_signal_time = ggml_time_us();
329-
330-
ggml_spin_lock(&shared->spin);
331321
}
332322

333323
ctx->suspending = false;
334324

335325
if (shared->ctx->features & GGML_THREADING_FEATURE_PERF) {
336326
ggml_perf_collect(&shared->ctx->wakeup_perf, perf_cycles_0,
337327
perf_time_0);
338-
}
339-
340-
ggml_spin_unlock(&shared->spin);
328+
};
341329
}
342330

343331
bool ggml_threading_is_suspending(struct ggml_threading_context *ctx) {
@@ -385,8 +373,6 @@ static void ggml_threading_setup_workers(struct ggml_threading_context *ctx,
385373
}
386374
} else if (current->wait) {
387375
if (shared->n_waiting < n_worker_threads) {
388-
PRINT_DEBUG("[main] wait_now will be set, expect %d workers wait\n",
389-
n_worker_threads);
390376
ggml_spin_unlock(&ctx->shared.spin);
391377
ggml_threading_suspend(ctx);
392378
ggml_spin_lock(&ctx->shared.spin);

tests/test-ggml-threading.c

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ lifecycle_runner(const struct ggml_compute_params *params,
214214
}
215215

216216
// Test thread lifecycle: start -> suspend -> resume -> stop
217-
static int test_lifecycle(void) {
217+
static int test_lifecycle(bool wait_on_done) {
218218
struct ggml_tensor node;
219219
memset(&node, 0, sizeof(struct ggml_tensor));
220220

@@ -243,14 +243,15 @@ static int test_lifecycle(void) {
243243
int threads_arr_len = sizeof(threads_arr) / sizeof(threads_arr[0]);
244244
int n_threads = 1;
245245

246+
enum ggml_threading_features features =
247+
wait_on_done ? GGML_THREADING_FEATURE_NONE
248+
: GGML_THREADING_FEATURE_WAIT_ON_DONE;
246249
for (int i = 0; i < threads_arr_len; i++) {
247250
n_threads = threads_arr[i];
248251
int start_time = (int)ggml_time_ms();
249-
ctx = ggml_threading_start(
250-
n_threads, NULL, lifecycle_runner,
251-
/*features*/ GGML_THREADING_FEATURE_WAIT_ON_DONE |
252-
GGML_THREADING_FEATURE_PERF,
253-
/*stages_time*/ NULL);
252+
ctx = ggml_threading_start(n_threads, NULL, lifecycle_runner,
253+
features | GGML_THREADING_FEATURE_PERF,
254+
/*stages_time*/ NULL);
254255
int elapsed = (int)ggml_time_ms() - start_time;
255256
if (elapsed > 5 * n_threads) {
256257
printf("[test-ggml-threading] %s: it took %d ms to start %d worker "
@@ -547,13 +548,17 @@ int main(void) {
547548
}
548549

549550
// lifecycle.
550-
{
551-
printf("[test-ggml-threading] test lifecycle ...\n");
551+
for (int i = 0; i < 2; i++) {
552+
bool wait_on_done = (i == 1);
553+
printf("[test-ggml-threading] test lifecycle (want_on_done = %d) ...\n",
554+
wait_on_done);
552555
++n_tests;
553556

554-
if (test_lifecycle() == 0) {
557+
if (test_lifecycle(wait_on_done) == 0) {
555558
++n_passed;
556-
printf("[test-ggml-threading] test lifecycle: ok\n\n");
559+
printf("[test-ggml-threading] test lifecycle (want_on_done = %d): "
560+
"ok\n\n",
561+
wait_on_done);
557562
}
558563
}
559564

0 commit comments

Comments
 (0)