
Commit 58830fd

jeffbolznv authored and mglambda committed
vulkan: compile shaders on-demand (ggml-org#11406)
Reduce first-run startup time and memory consumption. Should fix ggml-org#11339.
1 parent ad755d8 commit 58830fd
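
The idea behind the change, in brief: instead of compiling every Vulkan shader when the device is created, pipelines are only flagged during a dryrun pass over the compute graph and compiled afterwards if the graph actually uses them. The sketch below is a minimal standalone illustration of that flag-then-compile pattern; the type and function names (pipeline_t, request_pipeline, compile_needed) are made up for the example and are not the ggml-vulkan API.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Illustrative stand-in for vk_pipeline_struct: only the bookkeeping flags matter here.
struct pipeline_t {
    std::string name;
    bool needed   {};  // set when a dryrun requests this pipeline
    bool compiled {};  // set once the expensive compile has actually run
};

using pipeline_ptr = std::shared_ptr<pipeline_t>;

struct device_t {
    std::vector<pipeline_ptr> pipelines;
    bool need_compiles {};  // true if any pipeline is needed but not yet compiled
};

// Dryrun path: record that the pipeline will be used, but do not compile yet.
static void request_pipeline(device_t & dev, const pipeline_ptr & p) {
    if (!p->compiled) {
        p->needed = true;
        dev.need_compiles = true;
    }
}

// Post-dryrun path: compile only what was requested, skip everything else.
static void compile_needed(device_t & dev) {
    for (auto & p : dev.pipelines) {
        if (!p->needed || p->compiled) {
            continue;  // shader not used by the current graph, or already built
        }
        // ... SPIR-V -> pipeline creation would happen here ...
        p->compiled = true;
    }
    dev.need_compiles = false;
}

int main() {
    device_t dev;
    dev.pipelines = { std::make_shared<pipeline_t>(), std::make_shared<pipeline_t>() };
    dev.pipelines[0]->name = "matmul_f32";
    dev.pipelines[1]->name = "rope_f16";

    request_pipeline(dev, dev.pipelines[0]);   // the dryrun only touches matmul_f32
    if (dev.need_compiles) {
        compile_needed(dev);
    }
    for (const auto & p : dev.pipelines) {
        std::cout << p->name << ": " << (p->compiled ? "compiled" : "skipped") << "\n";
    }
    return 0;
}

Only the first pipeline gets compiled in this toy run, which is the same effect the commit aims for on first startup: shaders the model never dispatches are never built.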

File tree

1 file changed: +41 -23 lines changed


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 41 additions & 23 deletions
@@ -85,6 +85,10 @@ struct vk_pipeline_struct {
     uint32_t parameter_count;
     std::array<uint32_t, 3> wg_denoms;
     uint32_t align;
+    // set to true to request the pipeline is compiled after the dryrun
+    bool needed {};
+    // set to true when the shader has been compiled
+    bool compiled {};
 };

 typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
@@ -186,16 +190,19 @@ struct vk_device_struct {
     bool mul_mat_id_m;
     bool mul_mat_id_s;

-    vk_matmul_pipeline pipeline_matmul_f32;
-    vk_matmul_pipeline pipeline_matmul_f32_f16;
+    // set to true to indicate that some shaders need to be compiled after the dryrun
+    bool need_compiles {};
+
+    vk_matmul_pipeline pipeline_matmul_f32 {};
+    vk_matmul_pipeline pipeline_matmul_f32_f16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
     vk_pipeline pipeline_matmul_split_k_reduce;

     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];

-    vk_matmul_pipeline pipeline_matmul_id_f32;
+    vk_matmul_pipeline pipeline_matmul_id_f32 {};
     vk_matmul_pipeline2 pipeline_matmul_id_f16;
     vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;

@@ -776,13 +783,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
     GGML_ASSERT(parameter_count > 0);
     GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT

-    pipeline = std::make_shared<vk_pipeline_struct>();
-    pipeline->name = name;
-    pipeline->parameter_count = parameter_count;
-    pipeline->push_constant_size = push_constant_size;
-    pipeline->wg_denoms = wg_denoms;
-    pipeline->align = align;
-
     vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
     pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);

@@ -865,6 +865,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
     }

     pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
+    pipeline->compiled = true;

     {
         std::lock_guard<std::mutex> guard(device->mutex);
@@ -875,12 +876,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         std::lock_guard<std::mutex> guard(compile_count_mutex);
         assert(compile_count > 0);
         compile_count--;
-
-        // "Progress bar" for shader compiles
-        static uint32_t total_compile_count = 0;
-        if ((total_compile_count++ % 10) == 0) {
-            std::cerr << ".";
-        }
     }
     compile_count_cond.notify_all();
 }
@@ -906,6 +901,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline)
 static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
     VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
     device->pipeline_descriptor_set_requirements[pipeline->name] += n;
+    if (!pipeline->compiled) {
+        pipeline->needed = true;
+        device->need_compiles = true;
+    }
 }

 static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) {
@@ -1388,8 +1387,6 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");

-    std::cerr << "ggml_vulkan: Compiling shaders";
-
     // some shaders have a minimum subgroup size
     const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
     const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
@@ -1527,15 +1524,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
         }
     }

-    device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
-
-    device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    if (!device->pipeline_matmul_f32) {
+        device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_f32_f16) {
+        device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_id_f32) {
+        device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    }

     std::vector<std::future<void>> compiles;
     auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
                                               uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
                                               uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
+
+        if (!pipeline) {
+            pipeline = std::make_shared<vk_pipeline_struct>();
+            pipeline->name = name;
+            pipeline->parameter_count = parameter_count;
+            pipeline->push_constant_size = push_constant_size;
+            pipeline->wg_denoms = wg_denoms;
+            pipeline->align = align;
+        }
+
+        if (!pipeline->needed || pipeline->compiled) {
+            return;
+        }
         {
             // wait until fewer than N compiles are in progress
             uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -2050,7 +2065,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     for (auto &c : compiles) {
         c.wait();
     }
-    std::cerr << "Done!" << std::endl;
+    device->need_compiles = false;
 }

 static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
@@ -7656,6 +7671,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
     }
+    if (ctx->device->need_compiles) {
+        ggml_vk_load_shaders(ctx->device);
+    }
     ggml_vk_preallocate_buffers(ctx);
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
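
One detail worth noting for readers skimming the diff: the compile_count / compile_count_cond machinery edited in the progress-bar hunk, together with the `compiles` futures waited on near the end of ggml_vk_load_shaders, throttles how many shader compiles run in parallel (at most hardware_concurrency at a time, as the "wait until fewer than N compiles are in progress" comment says). A standalone sketch of that throttling pattern, with illustrative names and a placeholder in place of the real SPIR-V compile, might look like this:

#include <algorithm>
#include <condition_variable>
#include <cstdint>
#include <future>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Illustrative globals mirroring compile_count / compile_count_cond in the diff.
static std::mutex              compile_count_mutex;
static std::condition_variable compile_count_cond;
static uint32_t                compile_count = 0;

static void compile_one(int /*id*/) {
    // ... the expensive shader compile would run here ...
    {
        std::lock_guard<std::mutex> guard(compile_count_mutex);
        compile_count--;
    }
    compile_count_cond.notify_all();  // wake the producer so it can launch another compile
}

int main() {
    const uint32_t N = std::max(1u, std::thread::hardware_concurrency());
    std::vector<std::future<void>> compiles;

    for (int id = 0; id < 32; id++) {
        {
            // wait until fewer than N compiles are in progress
            std::unique_lock<std::mutex> guard(compile_count_mutex);
            compile_count_cond.wait(guard, [&] { return compile_count < N; });
            compile_count++;
        }
        compiles.push_back(std::async(std::launch::async, compile_one, id));
    }
    for (auto & c : compiles) {
        c.wait();  // mirrors the final wait loop in ggml_vk_load_shaders
    }
    std::cout << "all compiles finished" << std::endl;
}

With on-demand compilation, this same throttled pass can now run again on later graphs: any pipeline newly flagged as needed is compiled incrementally rather than up front.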
