@@ -1751,7 +1751,7 @@ struct ggml_compute_state_shared {
1751
1751
int64_t perf_node_start_cycles;
1752
1752
int64_t perf_node_start_time_us;
1753
1753
1754
- const int n_threads;
1754
+ int n_threads;
1755
1755
1756
1756
// synchronization primitives
1757
1757
atomic_int n_active; // num active threads
@@ -19486,12 +19486,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
19486
19486
if (n_threads <= 0) {
19487
19487
n_threads = GGML_DEFAULT_N_THREADS;
19488
19488
}
19489
- #if defined(GGML_USE_OPENMP)
19490
- // Limit the number of threads used to avoid deadlock
19491
- // ref: https://github.com/ggerganov/llama.cpp/pull/7606
19492
- n_threads = MIN(n_threads, omp_get_max_threads());
19493
- n_threads = MIN(n_threads, omp_get_thread_limit());
19494
- #endif
19495
19489
19496
19490
size_t work_size = 0;
19497
19491
@@ -19676,9 +19670,20 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
19676
19670
enum ggml_status compute_status = GGML_STATUS_SUCCESS;
19677
19671
19678
19672
#ifdef GGML_USE_OPENMP
19679
- #pragma omp parallel num_threads(n_threads)
19680
- {
19681
- ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
19673
+ if (n_threads > 1) {
19674
+ #pragma omp parallel num_threads(n_threads)
19675
+ {
19676
+ #pragma omp single
19677
+ {
19678
+ // update the number of threads from the actual number of threads that we got from OpenMP
19679
+ n_threads = omp_get_num_threads();
19680
+ workers[0].shared->n_threads = n_threads;
19681
+ workers[0].shared->n_active = n_threads;
19682
+ }
19683
+ ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
19684
+ }
19685
+ } else {
19686
+ ggml_graph_compute_thread(&workers[0]);
19682
19687
}
19683
19688
#else
19684
19689
// create thread pool
@@ -19724,7 +19729,12 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
19724
19729
}
19725
19730
}
19726
19731
19727
- const int n_threads = cplan->n_threads;
19732
+ int n_threads = cplan->n_threads;
19733
+
19734
+ #if defined(GGML_USE_OPENMP)
19735
+ n_threads = MIN(n_threads, omp_get_max_threads());
19736
+ n_threads = MIN(n_threads, omp_get_thread_limit());
19737
+ #endif
19728
19738
19729
19739
struct ggml_compute_state_shared state_shared = {
19730
19740
/*.cgraph =*/ cgraph,
0 commit comments