@@ -706,9 +706,21 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
706
706
q.cmd_buffer_idx = 0 ;
707
707
}
708
708
709
- static vk_buffer ggml_vk_create_buffer (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
709
+ static int32_t find_properties (const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
710
+ for (uint32_t i = 0 ; i < mem_props->memoryTypeCount ; ++i) {
711
+ vk::MemoryType memory_type = mem_props->memoryTypes [i];
712
+ if ((mem_req->memoryTypeBits & ((uint64_t )1 << i)) &&
713
+ (flags & memory_type.propertyFlags ) == flags &&
714
+ mem_props->memoryHeaps [memory_type.heapIndex ].size >= mem_req->size ) {
715
+ return static_cast <int32_t >(i);
716
+ }
717
+ }
718
+ return -1 ;
719
+ }
720
+
721
+ static vk_buffer ggml_vk_create_buffer (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags desired_flags = vk::MemoryPropertyFlags(0 )) {
710
722
#ifdef GGML_VULKAN_DEBUG
711
- std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " )" << std::endl;
723
+ std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " , " << to_string (desired_flags) << " )" << std::endl;
712
724
#endif
713
725
vk_buffer buf = std::make_shared<vk_buffer_struct>();
714
726
@@ -733,17 +745,15 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
733
745
734
746
vk::PhysicalDeviceMemoryProperties mem_props = ctx->device .lock ()->physical_device .getMemoryProperties ();
735
747
736
- uint32_t memory_type_index = UINT32_MAX;
737
-
738
- for (uint32_t i = 0 ; i < mem_props.memoryTypeCount ; ++i) {
739
- vk::MemoryType memory_type = mem_props.memoryTypes [i];
740
- if ((mem_req.memoryTypeBits & ((uint64_t )1 << i)) && (req_flags & memory_type.propertyFlags ) == req_flags && mem_props.memoryHeaps [memory_type.heapIndex ].size >= mem_req.size ) {
741
- memory_type_index = i;
742
- break ;
743
- }
748
+ uint32_t memory_type_index = -1 ;
749
+ if (desired_flags) {
750
+ memory_type_index = find_properties (&mem_props, &mem_req, req_flags | desired_flags);
751
+ }
752
+ if (memory_type_index == -1 ) {
753
+ memory_type_index = find_properties (&mem_props, &mem_req, req_flags);
744
754
}
745
755
746
- if (memory_type_index >= mem_props. memoryTypeCount ) {
756
+ if (memory_type_index == - 1 ) {
747
757
throw vk::OutOfDeviceMemoryError (" No suitable memory type found" );
748
758
}
749
759
@@ -775,7 +785,7 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
775
785
return buf;
776
786
}
777
787
778
- static vk_buffer ggml_vk_create_buffer_check (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
788
+ static vk_buffer ggml_vk_create_buffer_check (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags desired_flags = vk::MemoryPropertyFlags( 0 ) ) {
779
789
try {
780
790
return ggml_vk_create_buffer (ctx, size, req_flags);
781
791
} catch (const vk::SystemError& e) {
@@ -1419,7 +1429,7 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
1419
1429
#ifdef GGML_VULKAN_DEBUG
1420
1430
std::cerr << " ggml_vk_host_malloc(" << size << " )" << std::endl;
1421
1431
#endif
1422
- vk_buffer buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1432
+ vk_buffer buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eHostCached);
1423
1433
1424
1434
if (!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
1425
1435
fprintf (stderr, " WARNING: failed to allocate %.2f MB of pinned memory\n " ,
@@ -1565,7 +1575,7 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
1565
1575
static void ggml_vk_ensure_sync_staging_buffer (ggml_backend_vk_context * ctx, size_t size) {
1566
1576
if (ctx->sync_staging == nullptr || ctx->sync_staging ->size < size) {
1567
1577
ggml_vk_destroy_buffer (ctx->sync_staging );
1568
- ctx->sync_staging = ggml_vk_create_buffer_check (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1578
+ ctx->sync_staging = ggml_vk_create_buffer_check (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eHostCached);
1569
1579
}
1570
1580
}
1571
1581
@@ -3998,7 +4008,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
3998
4008
std::cerr << " qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << std::endl;
3999
4009
#endif
4000
4010
#if defined(GGML_VULKAN_RUN_TESTS)
4001
- ctx->staging = ggml_vk_create_buffer_check (ctx, 100ul * 1024ul * 1024ul , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4011
+ ctx->staging = ggml_vk_create_buffer_check (ctx, 100ul * 1024ul * 1024ul , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eHostCached);
4002
4012
ggml_vk_test_transfer (ctx, 8192 * 1000 , false );
4003
4013
ggml_vk_test_transfer (ctx, 8192 * 1000 , true );
4004
4014
@@ -4090,7 +4100,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4090
4100
if (ctx->staging != nullptr ) {
4091
4101
ggml_vk_destroy_buffer (ctx->staging );
4092
4102
}
4093
- ctx->staging = ggml_vk_create_buffer_check (ctx, ctx->staging_size , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4103
+ ctx->staging = ggml_vk_create_buffer_check (ctx, ctx->staging_size , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eHostCached);
4094
4104
}
4095
4105
}
4096
4106
0 commit comments