Skip to content

Commit 641002f

Browse files
jeffbolznv authored and ggerganov committed
vulkan : multithread pipeline creation (ggml/963)
1 parent 0de8b20 commit 641002f

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

ggml/src/ggml-vulkan.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include <unordered_map>
2121
#include <memory>
2222
#include <mutex>
23+
#include <future>
24+
#include <thread>
2325

2426
#include "ggml-impl.h"
2527
#include "ggml-backend-impl.h"
@@ -607,13 +609,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
607609

608610
GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
609611

610-
static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
612+
// variables to track number of compiles in progress
//
// compile_count is the number of pipeline compiles currently in flight. It is
// incremented by the scheduling code before a compile is launched and
// decremented by ggml_vk_create_pipeline_func when that compile finishes.
// Every access to compile_count is made while holding compile_count_mutex.
// compile_count_cond is notified after each decrement so that a scheduler
// blocked on the in-flight limit can re-check the count.
612+
613+
static uint32_t compile_count = 0;
614+
static std::mutex compile_count_mutex;
615+
static std::condition_variable compile_count_cond;
616+
617+
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
611618
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
612619
GGML_ASSERT(parameter_count > 0);
613620
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
614621

615-
std::lock_guard<std::mutex> guard(device->mutex);
616-
617622
pipeline = std::make_shared<vk_pipeline_struct>();
618623
pipeline->name = name;
619624
pipeline->parameter_count = parameter_count;
@@ -681,7 +686,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
681686
pipeline->layout);
682687
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
683688

684-
device->pipelines.insert({ pipeline->name, pipeline });
689+
{
690+
std::lock_guard<std::mutex> guard(device->mutex);
691+
device->pipelines.insert({ pipeline->name, pipeline });
692+
}
693+
694+
{
695+
std::lock_guard<std::mutex> guard(compile_count_mutex);
696+
assert(compile_count > 0);
697+
compile_count--;
698+
}
699+
compile_count_cond.notify_all();
685700
}
686701

687702
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -1194,6 +1209,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
11941209
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
11951210
device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
11961211

1212+
std::vector<std::future<void>> compiles;
1213+
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
1214+
{
1215+
// wait until fewer than N compiles are in progress
1216+
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
1217+
std::unique_lock<std::mutex> guard(compile_count_mutex);
1218+
while (compile_count >= N) {
1219+
compile_count_cond.wait(guard);
1220+
}
1221+
compile_count++;
1222+
}
1223+
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
1224+
};
1225+
11971226
if (device->fp16) {
11981227
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
11991228
ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
@@ -1743,6 +1772,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
17431772
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
17441773

17451774
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
1775+
1776+
for (auto &c : compiles) {
1777+
c.wait();
1778+
}
17461779
}
17471780

17481781
static vk_device ggml_vk_get_device(size_t idx) {

0 commit comments

Comments
 (0)