Skip to content

Commit 4b024ac

Browse files
luciferoushodlen
authored andcommitted
vulkan: Find optimal memory type but with fallback (ggml-org#5381)
* @0cc4m feedback * More feedback @0cc4m
1 parent d2e7cd0 commit 4b024ac

File tree

1 file changed

+42
-23
lines changed

1 file changed

+42
-23
lines changed

ggml-vulkan.cpp

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -707,9 +707,21 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
707707
q.cmd_buffer_idx = 0;
708708
}
709709

710-
static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
710+
static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
711+
for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
712+
vk::MemoryType memory_type = mem_props->memoryTypes[i];
713+
if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
714+
(flags & memory_type.propertyFlags) == flags &&
715+
mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
716+
return static_cast<int32_t>(i);
717+
}
718+
}
719+
return UINT32_MAX;
720+
}
721+
722+
static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
711723
#ifdef GGML_VULKAN_DEBUG
712-
std::cerr << "ggml_vk_create_buffer(" << size << ", " << to_string(req_flags) << ")" << std::endl;
724+
std::cerr << "ggml_vk_create_buffer(" << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags) << ")" << std::endl;
713725
#endif
714726
vk_buffer buf = std::make_shared<vk_buffer_struct>();
715727

@@ -736,15 +748,15 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
736748

737749
uint32_t memory_type_index = UINT32_MAX;
738750

739-
for (uint32_t i = 0; i < mem_props.memoryTypeCount; ++i) {
740-
vk::MemoryType memory_type = mem_props.memoryTypes[i];
741-
if ((mem_req.memoryTypeBits & ((uint64_t)1 << i)) && (req_flags & memory_type.propertyFlags) == req_flags && mem_props.memoryHeaps[memory_type.heapIndex].size >= mem_req.size) {
742-
memory_type_index = i;
743-
break;
744-
}
751+
memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
752+
buf->memory_property_flags = req_flags;
753+
754+
if (memory_type_index == UINT32_MAX && fallback_flags) {
755+
memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
756+
buf->memory_property_flags = fallback_flags;
745757
}
746758

747-
if (memory_type_index >= mem_props.memoryTypeCount) {
759+
if (memory_type_index == UINT32_MAX) {
748760
ctx->device.lock()->device.destroyBuffer(buf->buffer);
749761
buf->size = 0;
750762
throw vk::OutOfDeviceMemoryError("No suitable memory type found");
@@ -758,10 +770,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
758770
buf->size = 0;
759771
throw e;
760772
}
761-
buf->memory_property_flags = req_flags;
762773
buf->ptr = nullptr;
763774

764-
if (req_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
775+
if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
765776
buf->ptr = ctx->device.lock()->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
766777
}
767778

@@ -778,9 +789,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
778789
return buf;
779790
}
780791

781-
static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
792+
static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
782793
try {
783-
return ggml_vk_create_buffer(ctx, size, req_flags);
794+
return ggml_vk_create_buffer(ctx, size, req_flags, fallback_flags);
784795
} catch (const vk::SystemError& e) {
785796
std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
786797
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
@@ -791,16 +802,16 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
791802
static vk_buffer ggml_vk_create_buffer_device(ggml_backend_vk_context * ctx, size_t size) {
792803
vk_buffer buf;
793804
try {
794-
buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
795-
} catch (const vk::SystemError& e) {
796805
if (ctx->device.lock()->uma) {
797806
// Fall back to host memory type
798-
buf = ggml_vk_create_buffer_check(ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
807+
buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
799808
} else {
800-
std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
801-
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
802-
throw e;
809+
buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
803810
}
811+
} catch (const vk::SystemError& e) {
812+
std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
813+
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
814+
throw e;
804815
}
805816

806817
return buf;
@@ -1422,7 +1433,9 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
14221433
#ifdef GGML_VULKAN_DEBUG
14231434
std::cerr << "ggml_vk_host_malloc(" << size << ")" << std::endl;
14241435
#endif
1425-
vk_buffer buf = ggml_vk_create_buffer(ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1436+
vk_buffer buf = ggml_vk_create_buffer(ctx, size,
1437+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
1438+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
14261439

14271440
if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
14281441
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
@@ -1568,7 +1581,9 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
15681581
static void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
15691582
if (ctx->sync_staging == nullptr || ctx->sync_staging->size < size) {
15701583
ggml_vk_destroy_buffer(ctx->sync_staging);
1571-
ctx->sync_staging = ggml_vk_create_buffer_check(ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1584+
ctx->sync_staging = ggml_vk_create_buffer_check(ctx, size,
1585+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
1586+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
15721587
}
15731588
}
15741589

@@ -4082,7 +4097,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
40824097
std::cerr << "ggml_vk_preallocate_buffers(qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << ")" << std::endl;
40834098
#endif
40844099
#if defined(GGML_VULKAN_RUN_TESTS)
4085-
ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4100+
ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul,
4101+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached
4102+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
40864103
ggml_vk_test_transfer(ctx, 8192 * 1000, false);
40874104
ggml_vk_test_transfer(ctx, 8192 * 1000, true);
40884105

@@ -4174,7 +4191,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
41744191
if (ctx->staging != nullptr) {
41754192
ggml_vk_destroy_buffer(ctx->staging);
41764193
}
4177-
ctx->staging = ggml_vk_create_buffer_check(ctx, ctx->staging_size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4194+
ctx->staging = ggml_vk_create_buffer_check(ctx, ctx->staging_size,
4195+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4196+
vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
41784197
}
41794198
}
41804199

0 commit comments

Comments
 (0)