@@ -194,6 +194,8 @@ struct ggml_threading_context {
194
194
struct ggml_perf_stats wait_perf ;
195
195
struct ggml_perf_stats wakeup_perf ;
196
196
197
+ atomic_bool suspending ;
198
+
197
199
int64_t * stages_time ;
198
200
};
199
201
@@ -252,15 +254,45 @@ static void ggml_threading_cond_wait(struct ggml_compute_state *state) {
252
254
}
253
255
}
254
256
257
+ // Suspend the workers: set shared.wait_now, spin until shared.n_waiting equals n_threads - 1, then set ctx->suspending. Blocks the caller with no timeout.
258
+ void ggml_threading_suspend (struct ggml_threading_context * ctx ) {
259
+ if (ctx -> n_threads == 1 ) { // n_threads - 1 == 0 workers: nothing to suspend
260
+ return ;
261
+ }
262
+
263
+ struct ggml_compute_state_shared * shared = & ctx -> shared ;
264
+ // Raise wait_now while holding shared->spin.
265
+ ggml_spin_lock (& shared -> spin );
266
+ ctx -> shared .wait_now = true;
267
+ ggml_spin_unlock (& shared -> spin );
268
+
269
+ const int n_worker_threads = ctx -> n_threads - 1 ;
270
+ // Busy-wait (spin lock released) until n_waiting reaches the worker count.
271
+ while (ctx -> shared .n_waiting != n_worker_threads ) { // NOTE(review): n_waiting is read without shared->spin; confirm its type makes this safe
272
+ ggml_spin_pause ();
273
+ }
274
+ // Only now publish the suspended state, again under shared->spin.
275
+ ggml_spin_lock (& shared -> spin );
276
+ ctx -> suspending = true;
277
+ ggml_spin_unlock (& shared -> spin );
278
+ PRINT_DEBUG ("[main] saw %d workers waiting\n" , n_worker_threads );
279
+ }
280
+
255
281
// Wakeup all workers.
256
282
//
257
283
// Workers take some time to wake up, and have to lock spin after wakeup. Yield
258
284
// is used to avoid signaling too frequently. The current implementation is highly
259
285
// experimental. See tests/test-ggml-threading.c for details.
260
286
//
261
287
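// NOTE: the old ggml_threading_wakeup_workers() had to be called with shared->spin held; ggml_threading_resume() takes shared->spin itself, so callers release the lock before calling it.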
262
- static void
263
- ggml_threading_wakeup_workers (struct ggml_compute_state_shared * shared ) {
288
+ void ggml_threading_resume (struct ggml_threading_context * ctx ) {
289
+ if (ctx -> n_threads == 1 ) {
290
+ return ;
291
+ }
292
+
293
+ struct ggml_compute_state_shared * shared = & ctx -> shared ;
294
+ ggml_spin_lock (& shared -> spin );
295
+
264
296
int64_t perf_cycles_0 = 0 ;
265
297
int64_t perf_time_0 = 0 ;
266
298
@@ -269,12 +301,11 @@ ggml_threading_wakeup_workers(struct ggml_compute_state_shared *shared) {
269
301
perf_time_0 = ggml_time_us ();
270
302
}
271
303
272
- shared -> wait_now = false;
273
-
274
304
int loop_counter = 0 ;
275
- int notify_counter = 0 ;
276
305
int64_t last_signal_time = 0 ;
277
306
307
+ shared -> wait_now = false;
308
+
278
309
while (shared -> n_waiting != 0 ) {
279
310
ggml_spin_unlock (& shared -> spin );
280
311
@@ -294,22 +325,23 @@ ggml_threading_wakeup_workers(struct ggml_compute_state_shared *shared) {
294
325
GGML_ASSERT (pthread_mutex_lock (& shared -> mutex ) == 0 );
295
326
GGML_ASSERT (pthread_cond_broadcast (& shared -> cond ) == 0 );
296
327
GGML_ASSERT (pthread_mutex_unlock (& shared -> mutex ) == 0 );
297
- ++ notify_counter ;
298
328
last_signal_time = ggml_time_us ();
299
329
300
330
ggml_spin_lock (& shared -> spin );
301
331
}
302
332
333
+ ctx -> suspending = false;
334
+
303
335
if (shared -> ctx -> features & GGML_THREADING_FEATURE_PERF ) {
304
336
ggml_perf_collect (& shared -> ctx -> wakeup_perf , perf_cycles_0 ,
305
337
perf_time_0 );
306
338
}
307
339
308
- // if (notify_counter > 1) {
309
- // printf("%s: loop counter: %d, notify counter: %d\n", __func__,
310
- // loop_counter, notify_counter);
311
- // }
312
- UNUSED ( notify_counter ) ;
340
+ ggml_spin_unlock ( & shared -> spin );
341
+ }
342
+
343
+ bool ggml_threading_is_suspending ( struct ggml_threading_context * ctx ) { // Returns the current value of ctx->suspending without taking shared->spin.
344
+ return ctx -> suspending ; // atomic_bool: set at the end of ggml_threading_suspend(), cleared by ggml_threading_resume() and at start
313
345
}
314
346
315
347
// Setup workers for a task stage.
@@ -329,7 +361,9 @@ static void ggml_threading_setup_workers(struct ggml_threading_context *ctx,
329
361
330
362
if (current -> parallel ) {
331
363
if (shared -> n_waiting > 0 ) {
332
- ggml_threading_wakeup_workers (shared );
364
+ ggml_spin_unlock (& shared -> spin );
365
+ ggml_threading_resume (ctx );
366
+ ggml_spin_lock (& shared -> spin );
333
367
}
334
368
335
369
if ((ctx -> features & GGML_THREADING_FEATURE_WAIT_ON_DONE ) > 0 ) {
@@ -351,17 +385,11 @@ static void ggml_threading_setup_workers(struct ggml_threading_context *ctx,
351
385
}
352
386
} else if (current -> wait ) {
353
387
if (shared -> n_waiting < n_worker_threads ) {
354
- shared -> wait_now = true;
355
- PRINT_DEBUG ("[main] wait_now was set, expect %d workers wait\n" ,
388
+ PRINT_DEBUG ("[main] wait_now will be set, expect %d workers wait\n" ,
356
389
n_worker_threads );
357
- ggml_spin_unlock (& shared -> spin );
358
-
359
- while (shared -> n_waiting != n_worker_threads ) {
360
- ggml_spin_pause ();
361
- }
362
-
363
- ggml_spin_lock (& shared -> spin );
364
- PRINT_DEBUG ("[main] saw %d workers waiting\n" , n_worker_threads );
390
+ ggml_spin_unlock (& ctx -> shared .spin );
391
+ ggml_threading_suspend (ctx );
392
+ ggml_spin_lock (& ctx -> shared .spin );
365
393
}
366
394
}
367
395
@@ -376,7 +404,7 @@ ggml_thread_ret_t ggml_threading_graph_compute_thread(void *data) {
376
404
377
405
struct ggml_compute_state_shared * shared = state -> shared ;
378
406
GGML_ASSERT (shared );
379
- //GGML_ASSERT(shared->task_runner);
407
+ // GGML_ASSERT(shared->task_runner);
380
408
381
409
shared -> n_ready ++ ;
382
410
@@ -527,7 +555,7 @@ ggml_threading_compute_tensor(struct ggml_threading_context *ctx,
527
555
GGML_ASSERT (profiles [0 ].id == 1 );
528
556
529
557
memcpy (& node -> task_profile , & profiles [0 ],
530
- sizeof (struct ggml_task_profile ));
558
+ sizeof (struct ggml_task_profile ));
531
559
runner = ctx -> shared .task_runner ;
532
560
GGML_ASSERT (runner );
533
561
@@ -572,6 +600,7 @@ ggml_threading_start(int n_threads, ggml_threading_thread_runner *thread_runner,
572
600
ctx -> n_threads = n_threads ;
573
601
ctx -> features = features ;
574
602
ctx -> stages_time = stages_time ;
603
+ ctx -> suspending = false;
575
604
576
605
int n_workers = n_threads - 1 ;
577
606
if (n_workers > 0 ) {
@@ -633,9 +662,7 @@ void ggml_threading_stop(struct ggml_threading_context *ctx) {
633
662
PRINT_DEBUG ("[main] stopping thread pool ...\n" );
634
663
ctx -> shared .stop = true;
635
664
636
- ggml_spin_lock (& ctx -> shared .spin );
637
- ggml_threading_wakeup_workers (& ctx -> shared );
638
- ggml_spin_unlock (& ctx -> shared .spin );
665
+ ggml_threading_resume (ctx );
639
666
640
667
for (int j = 0 ; j < ctx -> n_threads - 1 ; j ++ ) {
641
668
GGML_ASSERT (pthread_join (ctx -> workers [j ].thrd , NULL ) == 0 );
0 commit comments