Skip to content

Commit fa864af

Browse files
committed
update shared state n_threads in parallel region
1 parent 7918ed7 commit fa864af

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

ggml.c

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,7 +1751,7 @@ struct ggml_compute_state_shared {
17511751
int64_t perf_node_start_cycles;
17521752
int64_t perf_node_start_time_us;
17531753

1754-
const int n_threads;
1754+
int n_threads;
17551755

17561756
// synchronization primitives
17571757
atomic_int n_active; // num active threads
@@ -19486,12 +19486,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
1948619486
if (n_threads <= 0) {
1948719487
n_threads = GGML_DEFAULT_N_THREADS;
1948819488
}
19489-
#if defined(GGML_USE_OPENMP)
19490-
// Limit the number of threads used to avoid deadlock
19491-
// ref: https://github.com/ggerganov/llama.cpp/pull/7606
19492-
n_threads = MIN(n_threads, omp_get_max_threads());
19493-
n_threads = MIN(n_threads, omp_get_thread_limit());
19494-
#endif
1949519489

1949619490
size_t work_size = 0;
1949719491

@@ -19676,9 +19670,20 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state *
1967619670
enum ggml_status compute_status = GGML_STATUS_SUCCESS;
1967719671

1967819672
#ifdef GGML_USE_OPENMP
19679-
#pragma omp parallel num_threads(n_threads)
19680-
{
19681-
ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
19673+
if (n_threads > 1) {
19674+
#pragma omp parallel num_threads(n_threads)
19675+
{
19676+
#pragma omp single
19677+
{
19678+
// update the number of threads from the actual number of threads that we got from OpenMP
19679+
n_threads = omp_get_num_threads();
19680+
workers[0].shared->n_threads = n_threads;
19681+
workers[0].shared->n_active = n_threads;
19682+
}
19683+
ggml_graph_compute_thread(&workers[omp_get_thread_num()]);
19684+
}
19685+
} else {
19686+
ggml_graph_compute_thread(&workers[0]);
1968219687
}
1968319688
#else
1968419689
// create thread pool
@@ -19724,7 +19729,12 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
1972419729
}
1972519730
}
1972619731

19727-
const int n_threads = cplan->n_threads;
19732+
int n_threads = cplan->n_threads;
19733+
19734+
#if defined(GGML_USE_OPENMP)
19735+
n_threads = MIN(n_threads, omp_get_max_threads());
19736+
n_threads = MIN(n_threads, omp_get_thread_limit());
19737+
#endif
1972819738

1972919739
struct ggml_compute_state_shared state_shared = {
1973019740
/*.cgraph =*/ cgraph,

0 commit comments

Comments
 (0)