|
20 | 20 | #include <unordered_map>
|
21 | 21 | #include <memory>
|
22 | 22 | #include <mutex>
|
| 23 | +#include <future> |
| 24 | +#include <thread> |
23 | 25 |
|
24 | 26 | #include "ggml-impl.h"
|
25 | 27 | #include "ggml-backend-impl.h"
|
@@ -607,13 +609,16 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
|
607 | 609 |
|
608 | 610 | GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend);
|
609 | 611 |
|
610 |
| -static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) { |
| 612 | +// variables to track number of compiles in progress |
| 613 | +static uint32_t compile_count = 0; |
| 614 | +static std::mutex compile_count_mutex; |
| 615 | +static std::condition_variable compile_count_cond; |
| 616 | + |
| 617 | +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) { |
611 | 618 | VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
|
612 | 619 | GGML_ASSERT(parameter_count > 0);
|
613 | 620 | GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
|
614 | 621 |
|
615 |
| - std::lock_guard<std::mutex> guard(device->mutex); |
616 |
| - |
617 | 622 | pipeline = std::make_shared<vk_pipeline_struct>();
|
618 | 623 | pipeline->name = name;
|
619 | 624 | pipeline->parameter_count = parameter_count;
|
@@ -681,7 +686,17 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co
|
681 | 686 | pipeline->layout);
|
682 | 687 | pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
|
683 | 688 |
|
684 |
| - device->pipelines.insert({ pipeline->name, pipeline }); |
| 689 | + { |
| 690 | + std::lock_guard<std::mutex> guard(device->mutex); |
| 691 | + device->pipelines.insert({ pipeline->name, pipeline }); |
| 692 | + } |
| 693 | + |
| 694 | + { |
| 695 | + std::lock_guard<std::mutex> guard(compile_count_mutex); |
| 696 | + assert(compile_count > 0); |
| 697 | + compile_count--; |
| 698 | + } |
| 699 | + compile_count_cond.notify_all(); |
685 | 700 | }
|
686 | 701 |
|
687 | 702 | static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
|
@@ -1194,6 +1209,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
1194 | 1209 | device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
|
1195 | 1210 | device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
|
1196 | 1211 |
|
| 1212 | + std::vector<std::future<void>> compiles; |
| 1213 | + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) { |
| 1214 | + { |
| 1215 | + // wait until fewer than N compiles are in progress |
| 1216 | + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); |
| 1217 | + std::unique_lock<std::mutex> guard(compile_count_mutex); |
| 1218 | + while (compile_count >= N) { |
| 1219 | + compile_count_cond.wait(guard); |
| 1220 | + } |
| 1221 | + compile_count++; |
| 1222 | + } |
| 1223 | + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align)); |
| 1224 | + }; |
| 1225 | + |
1197 | 1226 | if (device->fp16) {
|
1198 | 1227 | ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
|
1199 | 1228 | ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
|
@@ -1743,6 +1772,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
1743 | 1772 | ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
1744 | 1773 |
|
1745 | 1774 | ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
| 1775 | + |
| 1776 | + for (auto &c : compiles) { |
| 1777 | + c.wait(); |
| 1778 | + } |
1746 | 1779 | }
|
1747 | 1780 |
|
1748 | 1781 | static vk_device ggml_vk_get_device(size_t idx) {
|
|
0 commit comments