@@ -707,9 +707,21 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
707
707
q.cmd_buffer_idx = 0 ;
708
708
}
709
709
710
- static vk_buffer ggml_vk_create_buffer (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
710
+ static int32_t find_properties (const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
711
+ for (uint32_t i = 0 ; i < mem_props->memoryTypeCount ; ++i) {
712
+ vk::MemoryType memory_type = mem_props->memoryTypes [i];
713
+ if ((mem_req->memoryTypeBits & ((uint64_t )1 << i)) &&
714
+ (flags & memory_type.propertyFlags ) == flags &&
715
+ mem_props->memoryHeaps [memory_type.heapIndex ].size >= mem_req->size ) {
716
+ return static_cast <int32_t >(i);
717
+ }
718
+ }
719
+ return -1 ;
720
+ }
721
+
722
+ static vk_buffer ggml_vk_create_buffer (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0 )) {
711
723
#ifdef GGML_VULKAN_DEBUG
712
- std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " )" << std::endl;
724
+ std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " , " << to_string (fallback_flags) << " )" << std::endl;
713
725
#endif
714
726
vk_buffer buf = std::make_shared<vk_buffer_struct>();
715
727
@@ -734,17 +746,22 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
734
746
735
747
vk::PhysicalDeviceMemoryProperties mem_props = ctx->device .lock ()->physical_device .getMemoryProperties ();
736
748
737
- uint32_t memory_type_index = UINT32_MAX ;
749
+ uint32_t memory_type_index = - 1 ;
738
750
739
- for (uint32_t i = 0 ; i < mem_props.memoryTypeCount ; ++i) {
740
- vk::MemoryType memory_type = mem_props.memoryTypes [i];
741
- if ((mem_req.memoryTypeBits & ((uint64_t )1 << i)) && (req_flags & memory_type.propertyFlags ) == req_flags && mem_props.memoryHeaps [memory_type.heapIndex ].size >= mem_req.size ) {
742
- memory_type_index = i;
743
- break ;
744
- }
751
+ memory_type_index = find_properties (&mem_props, &mem_req, req_flags);
752
+ buf->memory_property_flags = req_flags;
753
+
754
+ // Failed to find memory type matching req_flags, but we can try again with fallback_flags if specified...
755
+ if (memory_type_index == -1 && fallback_flags && (
756
+ // ...as long as req_flags was either: 1) not DEVICE_LOCAL; or 2) DEVICE_LOCAL and device has UMA.
757
+ !(req_flags & vk::MemoryPropertyFlagBits::eDeviceLocal) ||
758
+ (req_flags & vk::MemoryPropertyFlagBits::eDeviceLocal && ctx->device .lock ()->uma ))
759
+ ) {
760
+ memory_type_index = find_properties (&mem_props, &mem_req, fallback_flags);
761
+ buf->memory_property_flags = fallback_flags;
745
762
}
746
763
747
- if (memory_type_index >= mem_props. memoryTypeCount ) {
764
+ if (memory_type_index == - 1 ) {
748
765
ctx->device .lock ()->device .destroyBuffer (buf->buffer );
749
766
buf->size = 0 ;
750
767
throw vk::OutOfDeviceMemoryError (" No suitable memory type found" );
@@ -758,10 +775,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
758
775
buf->size = 0 ;
759
776
throw e;
760
777
}
761
- buf->memory_property_flags = req_flags;
762
778
buf->ptr = nullptr ;
763
779
764
- if (req_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
780
+ if (buf-> memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
765
781
buf->ptr = ctx->device .lock ()->device .mapMemory (buf->device_memory , 0 , VK_WHOLE_SIZE);
766
782
}
767
783
@@ -778,9 +794,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
778
794
return buf;
779
795
}
780
796
781
- static vk_buffer ggml_vk_create_buffer_check (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
797
+ static vk_buffer ggml_vk_create_buffer_check (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags( 0 ) ) {
782
798
try {
783
- return ggml_vk_create_buffer (ctx, size, req_flags);
799
+ return ggml_vk_create_buffer (ctx, size, req_flags, fallback_flags );
784
800
} catch (const vk::SystemError& e) {
785
801
std::cerr << " ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
786
802
std::cerr << " ggml_vulkan: " << e.what () << std::endl;
@@ -791,16 +807,13 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
791
807
static vk_buffer ggml_vk_create_buffer_device (ggml_backend_vk_context * ctx, size_t size) {
792
808
vk_buffer buf;
793
809
try {
794
- buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
810
+ buf = ggml_vk_create_buffer (ctx, size,
811
+ vk::MemoryPropertyFlagBits::eDeviceLocal,
812
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
795
813
} catch (const vk::SystemError& e) {
796
- if (ctx->device .lock ()->uma ) {
797
- // Fall back to host memory type
798
- buf = ggml_vk_create_buffer_check (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
799
- } else {
800
- std::cerr << " ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
801
- std::cerr << " ggml_vulkan: " << e.what () << std::endl;
802
- throw e;
803
- }
814
+ std::cerr << " ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
815
+ std::cerr << " ggml_vulkan: " << e.what () << std::endl;
816
+ throw e;
804
817
}
805
818
806
819
return buf;
@@ -1422,7 +1435,9 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
1422
1435
#ifdef GGML_VULKAN_DEBUG
1423
1436
std::cerr << " ggml_vk_host_malloc(" << size << " )" << std::endl;
1424
1437
#endif
1425
- vk_buffer buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1438
+ vk_buffer buf = ggml_vk_create_buffer (ctx, size,
1439
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
1440
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
1426
1441
1427
1442
if (!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
1428
1443
fprintf (stderr, " WARNING: failed to allocate %.2f MB of pinned memory\n " ,
@@ -1568,7 +1583,9 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
1568
1583
// Guarantees that ctx->sync_staging refers to a buffer of at least `size` bytes.
// If the current buffer is missing or too small it is destroyed and a new one is
// created through ggml_vk_create_buffer_check; otherwise nothing happens.
static void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
    const bool already_sufficient = ctx->sync_staging != nullptr && ctx->sync_staging->size >= size;
    if (already_sufficient) {
        return;
    }

    // First choice: host-visible, coherent and cached memory.
    // Second flag set (passed as fallback_flags): the same without the cached bit.
    const vk::MemoryPropertyFlags host_coherent = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent;
    const vk::MemoryPropertyFlags host_coherent_cached = host_coherent | vk::MemoryPropertyFlagBits::eHostCached;

    // The old handle is handed to ggml_vk_destroy_buffer even when it is nullptr, as before.
    ggml_vk_destroy_buffer(ctx->sync_staging);
    ctx->sync_staging = ggml_vk_create_buffer_check(ctx, size, host_coherent_cached, host_coherent);
}
1574
1591
@@ -4082,7 +4099,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4082
4099
std::cerr << " ggml_vk_preallocate_buffers(qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << " )" << std::endl;
4083
4100
#endif
4084
4101
#if defined(GGML_VULKAN_RUN_TESTS)
4085
- ctx->staging = ggml_vk_create_buffer_check (ctx, 100ul * 1024ul * 1024ul , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4102
+ ctx->staging = ggml_vk_create_buffer_check (ctx, 100ul * 1024ul * 1024ul ,
4103
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached
4104
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
4086
4105
ggml_vk_test_transfer (ctx, 8192 * 1000 , false );
4087
4106
ggml_vk_test_transfer (ctx, 8192 * 1000 , true );
4088
4107
@@ -4174,7 +4193,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4174
4193
if (ctx->staging != nullptr ) {
4175
4194
ggml_vk_destroy_buffer (ctx->staging );
4176
4195
}
4177
- ctx->staging = ggml_vk_create_buffer_check (ctx, ctx->staging_size , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4196
+ ctx->staging = ggml_vk_create_buffer_check (ctx, ctx->staging_size ,
4197
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4198
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
4178
4199
}
4179
4200
}
4180
4201
0 commit comments