Skip to content

Commit 8994ac8

Browse files
committed
@0cc4m feedback
1 parent 99b8b43 commit 8994ac8

File tree

1 file changed

+48
-27
lines changed

1 file changed

+48
-27
lines changed

ggml-vulkan.cpp

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -707,9 +707,21 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
707707
q.cmd_buffer_idx = 0;
708708
}
709709

710-
static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
710+
// Scans the memory types in mem_props in index order and returns the index of
// the first one that satisfies all three conditions:
//   - its bit is set in mem_req->memoryTypeBits,
//   - its propertyFlags contain every bit set in flags,
//   - the heap it belongs to is at least mem_req->size bytes large.
// Returns -1 when no memory type satisfies all of them.
static int32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
    for (uint32_t type_idx = 0; type_idx < mem_props->memoryTypeCount; ++type_idx) {
        // Bit type_idx of memoryTypeBits must be set for this type to be considered.
        if (!(mem_req->memoryTypeBits & ((uint64_t)1 << type_idx))) {
            continue;
        }

        const vk::MemoryType candidate = mem_props->memoryTypes[type_idx];

        // Every requested property bit has to be present on the candidate.
        if ((flags & candidate.propertyFlags) != flags) {
            continue;
        }

        // The backing heap's total size must cover the requested size.
        if (mem_props->memoryHeaps[candidate.heapIndex].size < mem_req->size) {
            continue;
        }

        return static_cast<int32_t>(type_idx);
    }

    return -1;
}
721+
722+
static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
711723
#ifdef GGML_VULKAN_DEBUG
712-
std::cerr << "ggml_vk_create_buffer(" << size << ", " << to_string(req_flags) << ")" << std::endl;
724+
std::cerr << "ggml_vk_create_buffer(" << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")" << std::endl;
713725
#endif
714726
vk_buffer buf = std::make_shared<vk_buffer_struct>();
715727

@@ -734,17 +746,22 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
734746

735747
vk::PhysicalDeviceMemoryProperties mem_props = ctx->device.lock()->physical_device.getMemoryProperties();
736748

737-
uint32_t memory_type_index = UINT32_MAX;
749+
uint32_t memory_type_index = -1;
738750

739-
for (uint32_t i = 0; i < mem_props.memoryTypeCount; ++i) {
740-
vk::MemoryType memory_type = mem_props.memoryTypes[i];
741-
if ((mem_req.memoryTypeBits & ((uint64_t)1 << i)) && (req_flags & memory_type.propertyFlags) == req_flags && mem_props.memoryHeaps[memory_type.heapIndex].size >= mem_req.size) {
742-
memory_type_index = i;
743-
break;
744-
}
751+
memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
752+
buf->memory_property_flags = req_flags;
753+
754+
// Failed to find memory type matching req_flags, but we can try again with fallback_flags if specified...
755+
if (memory_type_index == -1 && fallback_flags && (
756+
// ...as long as req_flags was either: 1) not DEVICE_LOCAL; or 2) DEVICE_LOCAL and device has UMA.
757+
!(req_flags & vk::MemoryPropertyFlagBits::eDeviceLocal) ||
758+
(req_flags & vk::MemoryPropertyFlagBits::eDeviceLocal && ctx->device.lock()->uma))
759+
) {
760+
memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
761+
buf->memory_property_flags = fallback_flags;
745762
}
746763

747-
if (memory_type_index >= mem_props.memoryTypeCount) {
764+
if (memory_type_index == -1) {
748765
ctx->device.lock()->device.destroyBuffer(buf->buffer);
749766
buf->size = 0;
750767
throw vk::OutOfDeviceMemoryError("No suitable memory type found");
@@ -758,10 +775,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
758775
buf->size = 0;
759776
throw e;
760777
}
761-
buf->memory_property_flags = req_flags;
762778
buf->ptr = nullptr;
763779

764-
if (req_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
780+
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
765781
buf->ptr = ctx->device.lock()->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
766782
}
767783

@@ -778,9 +794,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
778794
return buf;
779795
}
780796

781-
static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
797+
static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
782798
try {
783-
return ggml_vk_create_buffer(ctx, size, req_flags);
799+
return ggml_vk_create_buffer(ctx, size, req_flags, fallback_flags);
784800
} catch (const vk::SystemError& e) {
785801
std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
786802
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
@@ -791,16 +807,13 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
791807
static vk_buffer ggml_vk_create_buffer_device(ggml_backend_vk_context * ctx, size_t size) {
792808
vk_buffer buf;
793809
try {
794-
buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
810+
buf = ggml_vk_create_buffer(ctx, size,
811+
vk::MemoryPropertyFlagBits::eDeviceLocal,
812+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
795813
} catch (const vk::SystemError& e) {
796-
if (ctx->device.lock()->uma) {
797-
// Fall back to host memory type
798-
buf = ggml_vk_create_buffer_check(ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
799-
} else {
800-
std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
801-
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
802-
throw e;
803-
}
814+
std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
815+
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
816+
throw e;
804817
}
805818

806819
return buf;
@@ -1422,7 +1435,9 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
14221435
#ifdef GGML_VULKAN_DEBUG
14231436
std::cerr << "ggml_vk_host_malloc(" << size << ")" << std::endl;
14241437
#endif
1425-
vk_buffer buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1438+
vk_buffer buf = ggml_vk_create_buffer(ctx, size,
1439+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
1440+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
14261441

14271442
if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
14281443
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
@@ -1568,7 +1583,9 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
15681583
static void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
15691584
if (ctx->sync_staging == nullptr || ctx->sync_staging->size < size) {
15701585
ggml_vk_destroy_buffer(ctx->sync_staging);
1571-
ctx->sync_staging = ggml_vk_create_buffer_check(ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1586+
ctx->sync_staging = ggml_vk_create_buffer_check(ctx, size,
1587+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
1588+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
15721589
}
15731590
}
15741591

@@ -4082,7 +4099,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
40824099
std::cerr << "ggml_vk_preallocate_buffers(qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << ")" << std::endl;
40834100
#endif
40844101
#if defined(GGML_VULKAN_RUN_TESTS)
4085-
ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4102+
ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul,
4103+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4104+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
40864105
ggml_vk_test_transfer(ctx, 8192 * 1000, false);
40874106
ggml_vk_test_transfer(ctx, 8192 * 1000, true);
40884107

@@ -4174,7 +4193,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
41744193
if (ctx->staging != nullptr) {
41754194
ggml_vk_destroy_buffer(ctx->staging);
41764195
}
4177-
ctx->staging = ggml_vk_create_buffer_check(ctx, ctx->staging_size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4196+
ctx->staging = ggml_vk_create_buffer_check(ctx, ctx->staging_size,
4197+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4198+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
41784199
}
41794200
}
41804201

0 commit comments

Comments
 (0)