threads: changing to a mutex/condvar based thread pool. #710

Closed · wants to merge 2 commits
5 changes: 4 additions & 1 deletion CMakeLists.txt
@@ -229,7 +229,10 @@ endif()

add_library(ggml OBJECT
ggml.c
ggml.h)
ggml.h
thpool.c
thpool.h
)

target_include_directories(ggml PUBLIC .)
target_compile_features(ggml PUBLIC c_std_11) # don't bump
17 changes: 10 additions & 7 deletions Makefile
@@ -138,6 +138,9 @@ default: main quantize perplexity embedding
# Build library
#

thpool.o: thpool.c thpool.h
$(CC) $(CFLAGS) -c thpool.c -o thpool.o

ggml.o: ggml.c ggml.h
$(CC) $(CFLAGS) -c ggml.c -o ggml.o

@@ -150,20 +153,20 @@ common.o: examples/common.cpp examples/common.h
clean:
rm -vf *.o main quantize perplexity embedding

main: examples/main/main.cpp ggml.o llama.o common.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o llama.o common.o -o main $(LDFLAGS)
main: examples/main/main.cpp thpool.o ggml.o llama.o common.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp thpool.o ggml.o llama.o common.o -o main $(LDFLAGS)
@echo
@echo '==== Run ./main -h for help. ===='
@echo

quantize: examples/quantize/quantize.cpp ggml.o llama.o
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp ggml.o llama.o -o quantize $(LDFLAGS)
quantize: examples/quantize/quantize.cpp thpool.o ggml.o llama.o
$(CXX) $(CXXFLAGS) examples/quantize/quantize.cpp thpool.o ggml.o llama.o -o quantize $(LDFLAGS)

perplexity: examples/perplexity/perplexity.cpp ggml.o llama.o common.o
$(CXX) $(CXXFLAGS) examples/perplexity/perplexity.cpp ggml.o llama.o common.o -o perplexity $(LDFLAGS)
$(CXX) $(CXXFLAGS) examples/perplexity/perplexity.cpp thpool.o ggml.o llama.o common.o -o perplexity $(LDFLAGS)

embedding: examples/embedding/embedding.cpp ggml.o llama.o common.o
$(CXX) $(CXXFLAGS) examples/embedding/embedding.cpp ggml.o llama.o common.o -o embedding $(LDFLAGS)
embedding: examples/embedding/embedding.cpp thpool.o ggml.o llama.o common.o
$(CXX) $(CXXFLAGS) examples/embedding/embedding.cpp thpool.o ggml.o llama.o common.o -o embedding $(LDFLAGS)

#
# Tests
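For reference, the thpool.c/thpool.h pair wired into the build above is a standalone mutex/condvar thread-pool implementation; the header itself is not part of this diff. Inferred from the call sites later in the patch, it exposes roughly the following interface — a sketch, not the verbatim header. (Note in passing that the perplexity target's dependency list above was not updated with thpool.o, unlike its recipe line.)

    /* thpool.h as assumed from the call sites in this PR (threadpool,
       thpool_init, thpool_add_work, thpool_wait, thpool_destroy);
       the exact signatures are inferred, not copied. */
    typedef struct thpool_ * threadpool;

    /* Start num_threads workers that sleep on a condition variable
       until jobs are queued. */
    threadpool thpool_init(int num_threads);

    /* Queue one job; a sleeping worker is woken to run function_p(arg_p). */
    int thpool_add_work(threadpool pool, void (*function_p)(void *), void * arg_p);

    /* Block the caller until the queue is empty and all workers are idle. */
    void thpool_wait(threadpool pool);

    /* Signal all workers to exit, join them, and free the pool. */
    void thpool_destroy(threadpool pool);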
199 changes: 39 additions & 160 deletions ggml.c
@@ -3,6 +3,8 @@

#include "ggml.h"

#include "thpool.h"

#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -51,28 +53,11 @@ static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec));
}

typedef HANDLE pthread_t;

typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL)
{
return EAGAIN;
}

*out = handle;
return 0;
}

static int pthread_join(pthread_t thread, void* unused) {
return (int) WaitForSingleObject(thread, INFINITE);
}

static int sched_yield (void) {
Sleep (0);
return 0;
}
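The Win32 pthread_create/pthread_join shims can go away because thread creation now happens inside thpool; only the atomics shims and sched_yield remain. (Presumably thpool.c itself then needs a pthreads layer, or an equivalent, to build on Windows.)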

#else
#include <pthread.h>
#include <stdatomic.h>
@@ -2697,6 +2682,7 @@ struct ggml_context {

struct ggml_scratch scratch;
struct ggml_scratch scratch_save;
threadpool tpool;
};

struct ggml_context_container {
@@ -2981,6 +2967,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
/*.objects_end =*/ NULL,
/*.scratch =*/ { 0, 0, NULL, },
/*.scratch_save =*/ { 0, 0, NULL, },
/*.tpool =*/ NULL,
};

GGML_ASSERT(ctx->mem_buffer != NULL); // check for allocation failure
@@ -9113,6 +9100,19 @@ typedef pthread_t ggml_thread_t;
#define ggml_thread_create pthread_create
#define ggml_thread_join pthread_join

typedef pthread_mutex_t ggml_mutex_t;
typedef pthread_cond_t ggml_cond_t;

#define ggml_mutex_init pthread_mutex_init
#define ggml_mutex_destroy pthread_mutex_destroy
#define ggml_cond_init pthread_cond_init
#define ggml_cond_destroy pthread_cond_destroy

#define ggml_mutex_lock pthread_mutex_lock
#define ggml_mutex_unlock pthread_mutex_unlock
#define ggml_cond_broadcast pthread_cond_broadcast
#define ggml_cond_wait pthread_cond_wait

#else

//typedef pthread_spinlock_t ggml_lock_t;
@@ -9131,102 +9131,56 @@ typedef int ggml_lock_t;

#define GGML_LOCK_INITIALIZER 0

typedef pthread_t ggml_thread_t;

#define ggml_thread_create pthread_create
#define ggml_thread_join pthread_join

#define ggml_mutex_init pthread_mutex_init
#define ggml_mutex_destroy pthread_mutex_destroy
#define ggml_cond_init pthread_cond_init
#define ggml_cond_destroy pthread_cond_destroy

#define ggml_mutex_lock pthread_mutex_lock
#define ggml_mutex_unlock pthread_mutex_unlock
#define ggml_cond_broadcast pthread_cond_broadcast
#define ggml_cond_wait pthread_cond_wait

#endif
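These wrappers exist so the scheduler can block instead of spin. As a minimal sketch (not code from this PR), the canonical pattern they enable looks like this: the waiter re-checks its predicate under the mutex, and the producer flips the predicate before broadcasting.

    #include <pthread.h>
    #include <stdbool.h>

    static pthread_mutex_t mtx  = PTHREAD_MUTEX_INITIALIZER;
    static pthread_cond_t  cond = PTHREAD_COND_INITIALIZER;
    static bool has_work = false;

    static void wait_for_work(void) {
        pthread_mutex_lock(&mtx);
        while (!has_work) {                 // loop: wakeups can be spurious
            pthread_cond_wait(&cond, &mtx); // atomically unlock + sleep
        }
        pthread_mutex_unlock(&mtx);
    }

    static void post_work(void) {
        pthread_mutex_lock(&mtx);
        has_work = true;
        pthread_mutex_unlock(&mtx);
        pthread_cond_broadcast(&cond);      // wake every sleeping waiter
    }

A sleeping waiter consumes no CPU, unlike the spin loops being deleted below, which burn a core on ggml_lock_lock/ggml_lock_unlock until a flag flips.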

struct ggml_compute_state_shared {
ggml_lock_t spin;

int n_threads;

// synchronization primitives
atomic_int n_ready;
atomic_bool has_work;
atomic_bool stop; // stop all threads
};

struct ggml_compute_state {
ggml_thread_t thrd;

struct ggml_compute_params params;
struct ggml_tensor * node;

struct ggml_compute_state_shared * shared;
};

static thread_ret_t ggml_graph_compute_thread(void * data) {
static void ggml_graph_compute_thread(void * data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data;

const int n_threads = state->shared->n_threads;

while (true) {
if (atomic_fetch_add(&state->shared->n_ready, 1) == n_threads - 1) {
atomic_store(&state->shared->has_work, false);
} else {
while (atomic_load(&state->shared->has_work)) {
if (atomic_load(&state->shared->stop)) {
return 0;
}
ggml_lock_lock (&state->shared->spin);
ggml_lock_unlock(&state->shared->spin);
}
}

atomic_fetch_sub(&state->shared->n_ready, 1);

// wait for work
while (!atomic_load(&state->shared->has_work)) {
if (atomic_load(&state->shared->stop)) {
return 0;
}
ggml_lock_lock (&state->shared->spin);
ggml_lock_unlock(&state->shared->spin);
}

// check if we should stop
if (atomic_load(&state->shared->stop)) {
break;
}

if (state->node) {
if (state->params.ith < state->params.nth) {
ggml_compute_forward(&state->params, state->node);
}

state->node = NULL;
} else {
break;
if (state->node) {
if (state->params.ith < state->params.nth) {
ggml_compute_forward(&state->params, state->node);
}
state->node = NULL;
}

return 0;
}

void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
const int n_threads = cgraph->n_threads;

struct ggml_compute_state_shared state_shared = {
/*.spin =*/ GGML_LOCK_INITIALIZER,
/*.n_threads =*/ n_threads,
/*.n_ready =*/ 0,
/*.has_work =*/ false,
/*.stop =*/ false,
};
struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;

// create thread pool
if (n_threads > 1) {
ggml_lock_init(&state_shared.spin);

atomic_store(&state_shared.has_work, true);

ctx->tpool = thpool_init(n_threads);
for (int j = 0; j < n_threads - 1; j++) {
workers[j] = (struct ggml_compute_state) {
.thrd = 0,
.params = {
.type = GGML_TASK_COMPUTE,
.ith = j + 1,
@@ -9237,10 +9191,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.node = NULL,
.shared = &state_shared,
};

int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]);
GGML_ASSERT(rc == 0);
UNUSED(rc);
}
}
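Note that thpool_init(n_threads) starts n_threads pool workers, while the dispatch loops below only ever queue n_threads - 1 jobs (the calling thread computes slice 0 itself), so one pool thread stays idle. Each queued job is now one-shot: ggml_graph_compute_thread runs its slice and returns, with all sleeping and waking handled inside thpool.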

@@ -9478,15 +9428,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

// COMPUTE
if (node->n_tasks > 1) {
if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
atomic_store(&state_shared.has_work, false);
}

while (atomic_load(&state_shared.has_work)) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

// launch thread pool
for (int j = 0; j < n_threads - 1; j++) {
workers[j].params = (struct ggml_compute_params) {
@@ -9497,51 +9438,20 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.wdata = cgraph->work ? cgraph->work->data : NULL,
};
workers[j].node = node;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]);
}

atomic_fetch_sub(&state_shared.n_ready, 1);

while (atomic_load(&state_shared.n_ready) > 0) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

atomic_store(&state_shared.has_work, true);
}

params.type = GGML_TASK_COMPUTE;
ggml_compute_forward(&params, node);

// wait for thread pool
if (node->n_tasks > 1) {
if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
atomic_store(&state_shared.has_work, false);
}

while (atomic_load(&state_shared.has_work)) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

atomic_fetch_sub(&state_shared.n_ready, 1);

while (atomic_load(&state_shared.n_ready) != 0) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}
thpool_wait(ctx->tpool);
}
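The per-node pattern is now a plain fork-join: queue n_tasks - 1 slices on the pool, compute slice 0 on the calling thread, then barrier on thpool_wait. A self-contained sketch of that shape, assuming the thpool interface outlined earlier:

    #include <stdio.h>
    #include "thpool.h" // assumed interface, as sketched above

    struct slice { int ith, nth; };

    static void run_slice(void * arg) {
        struct slice * s = (struct slice *) arg;
        printf("slice %d of %d\n", s->ith, s->nth);
    }

    int main(void) {
        enum { N_THREADS = 4 };
        threadpool pool = thpool_init(N_THREADS);
        struct slice slices[N_THREADS];

        // fork: queue every slice but the first
        for (int j = 1; j < N_THREADS; j++) {
            slices[j] = (struct slice) { .ith = j, .nth = N_THREADS };
            thpool_add_work(pool, run_slice, &slices[j]);
        }

        // the calling thread takes slice 0, as ggml_graph_compute does
        slices[0] = (struct slice) { .ith = 0, .nth = N_THREADS };
        run_slice(&slices[0]);

        thpool_wait(pool);    // join: all queued slices have finished
        thpool_destroy(pool);
        return 0;
    }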

// FINALIZE
if (node->n_tasks > 1) {
if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
atomic_store(&state_shared.has_work, false);
}

while (atomic_load(&state_shared.has_work)) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

// launch thread pool
for (int j = 0; j < n_threads - 1; j++) {
workers[j].params = (struct ggml_compute_params) {
@@ -9552,38 +9462,16 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.wdata = cgraph->work ? cgraph->work->data : NULL,
};
workers[j].node = node;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]);
}

atomic_fetch_sub(&state_shared.n_ready, 1);

while (atomic_load(&state_shared.n_ready) > 0) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

atomic_store(&state_shared.has_work, true);
}

params.type = GGML_TASK_FINALIZE;
ggml_compute_forward(&params, node);

// wait for thread pool
if (node->n_tasks > 1) {
if (atomic_fetch_add(&state_shared.n_ready, 1) == n_threads - 1) {
atomic_store(&state_shared.has_work, false);
}

while (atomic_load(&state_shared.has_work)) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}

atomic_fetch_sub(&state_shared.n_ready, 1);

while (atomic_load(&state_shared.n_ready) != 0) {
ggml_lock_lock (&state_shared.spin);
ggml_lock_unlock(&state_shared.spin);
}
thpool_wait(ctx->tpool);
}

// performance stats (node)
Expand All @@ -9599,16 +9487,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)

// join thread pool
if (n_threads > 1) {
atomic_store(&state_shared.stop, true);
atomic_store(&state_shared.has_work, true);

for (int j = 0; j < n_threads - 1; j++) {
int rc = ggml_thread_join(workers[j].thrd, NULL);
GGML_ASSERT(rc == 0);
UNUSED(rc);
}

ggml_lock_destroy(&state_shared.spin);
thpool_destroy(ctx->tpool);
}
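Since the pool is created at the top of ggml_graph_compute and destroyed here at the bottom, each graph evaluation still pays thread creation and join once per call; keeping the pool alive in ggml_context (where the new tpool field lives) across calls would amortize that cost.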

// performance stats (graph)