@@ -707,9 +707,21 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
707
707
q.cmd_buffer_idx = 0 ;
708
708
}
709
709
710
- static vk_buffer ggml_vk_create_buffer (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
710
+ static uint32_t find_properties (const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
711
+ for (uint32_t i = 0 ; i < mem_props->memoryTypeCount ; ++i) {
712
+ vk::MemoryType memory_type = mem_props->memoryTypes [i];
713
+ if ((mem_req->memoryTypeBits & ((uint64_t )1 << i)) &&
714
+ (flags & memory_type.propertyFlags ) == flags &&
715
+ mem_props->memoryHeaps [memory_type.heapIndex ].size >= mem_req->size ) {
716
+ return static_cast <int32_t >(i);
717
+ }
718
+ }
719
+ return UINT32_MAX;
720
+ }
721
+
722
+ static vk_buffer ggml_vk_create_buffer (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0 )) {
711
723
#ifdef GGML_VULKAN_DEBUG
712
- std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " )" << std::endl;
724
+ std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " , " << to_string (fallback_flags) << " )" << std::endl;
713
725
#endif
714
726
vk_buffer buf = std::make_shared<vk_buffer_struct>();
715
727
@@ -736,15 +748,15 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
736
748
737
749
uint32_t memory_type_index = UINT32_MAX;
738
750
739
- for ( uint32_t i = 0 ; i < mem_props. memoryTypeCount ; ++i) {
740
- vk::MemoryType memory_type = mem_props. memoryTypes [i] ;
741
- if ((mem_req. memoryTypeBits & (( uint64_t ) 1 << i)) && (req_flags & memory_type. propertyFlags ) == req_flags && mem_props. memoryHeaps [memory_type. heapIndex ]. size >= mem_req. size ) {
742
- memory_type_index = i;
743
- break ;
744
- }
751
+ memory_type_index = find_properties (& mem_props, &mem_req, req_flags);
752
+ buf-> memory_property_flags = req_flags ;
753
+
754
+ if (memory_type_index == UINT32_MAX && fallback_flags) {
755
+ memory_type_index = find_properties (&mem_props, &mem_req, fallback_flags) ;
756
+ buf-> memory_property_flags = fallback_flags;
745
757
}
746
758
747
- if (memory_type_index >= mem_props. memoryTypeCount ) {
759
+ if (memory_type_index == UINT32_MAX ) {
748
760
ctx->device .lock ()->device .destroyBuffer (buf->buffer );
749
761
buf->size = 0 ;
750
762
throw vk::OutOfDeviceMemoryError (" No suitable memory type found" );
@@ -758,10 +770,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
758
770
buf->size = 0 ;
759
771
throw e;
760
772
}
761
- buf->memory_property_flags = req_flags;
762
773
buf->ptr = nullptr ;
763
774
764
- if (req_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
775
+ if (buf-> memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
765
776
buf->ptr = ctx->device .lock ()->device .mapMemory (buf->device_memory , 0 , VK_WHOLE_SIZE);
766
777
}
767
778
@@ -778,9 +789,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
778
789
return buf;
779
790
}
780
791
781
- static vk_buffer ggml_vk_create_buffer_check (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
792
+ static vk_buffer ggml_vk_create_buffer_check (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags( 0 ) ) {
782
793
try {
783
- return ggml_vk_create_buffer (ctx, size, req_flags);
794
+ return ggml_vk_create_buffer (ctx, size, req_flags, fallback_flags );
784
795
} catch (const vk::SystemError& e) {
785
796
std::cerr << " ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
786
797
std::cerr << " ggml_vulkan: " << e.what () << std::endl;
@@ -791,16 +802,16 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
791
802
static vk_buffer ggml_vk_create_buffer_device (ggml_backend_vk_context * ctx, size_t size) {
792
803
vk_buffer buf;
793
804
try {
794
- buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
795
- } catch (const vk::SystemError& e) {
796
805
if (ctx->device .lock ()->uma ) {
797
806
// Fall back to host memory type
798
- buf = ggml_vk_create_buffer_check (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
807
+ buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
799
808
} else {
800
- std::cerr << " ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
801
- std::cerr << " ggml_vulkan: " << e.what () << std::endl;
802
- throw e;
809
+ buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
803
810
}
811
+ } catch (const vk::SystemError& e) {
812
+ std::cerr << " ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
813
+ std::cerr << " ggml_vulkan: " << e.what () << std::endl;
814
+ throw e;
804
815
}
805
816
806
817
return buf;
@@ -1422,7 +1433,9 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
1422
1433
#ifdef GGML_VULKAN_DEBUG
1423
1434
std::cerr << " ggml_vk_host_malloc(" << size << " )" << std::endl;
1424
1435
#endif
1425
- vk_buffer buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1436
+ vk_buffer buf = ggml_vk_create_buffer (ctx, size,
1437
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
1438
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
1426
1439
1427
1440
if (!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
1428
1441
fprintf (stderr, " WARNING: failed to allocate %.2f MB of pinned memory\n " ,
@@ -1568,7 +1581,9 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
1568
1581
// Ensures ctx->sync_staging is a buffer of at least `size` bytes.
//
// When the current buffer is missing or smaller than `size`, it is handed to
// ggml_vk_destroy_buffer and replaced by a new buffer from ggml_vk_create_buffer_check.
// The new buffer is requested with host-visible + host-coherent + host-cached memory,
// and host-visible + host-coherent is passed as the fallback flag set.
static void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
    const bool already_sufficient = ctx->sync_staging != nullptr && ctx->sync_staging->size >= size;
    if (already_sufficient) {
        return;
    }

    const vk::MemoryPropertyFlags preferred_flags =
        vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
    const vk::MemoryPropertyFlags fallback_flags =
        vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent;

    ggml_vk_destroy_buffer(ctx->sync_staging);
    ctx->sync_staging = ggml_vk_create_buffer_check(ctx, size, preferred_flags, fallback_flags);
}
1574
1589
@@ -4082,7 +4097,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4082
4097
std::cerr << " ggml_vk_preallocate_buffers(qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << " )" << std::endl;
4083
4098
#endif
4084
4099
#if defined(GGML_VULKAN_RUN_TESTS)
4085
- ctx->staging = ggml_vk_create_buffer_check (ctx, 100ul * 1024ul * 1024ul , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4100
+ ctx->staging = ggml_vk_create_buffer_check (ctx, 100ul * 1024ul * 1024ul ,
4101
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached
4102
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
4086
4103
ggml_vk_test_transfer (ctx, 8192 * 1000 , false );
4087
4104
ggml_vk_test_transfer (ctx, 8192 * 1000 , true );
4088
4105
@@ -4174,7 +4191,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
4174
4191
if (ctx->staging != nullptr ) {
4175
4192
ggml_vk_destroy_buffer (ctx->staging );
4176
4193
}
4177
- ctx->staging = ggml_vk_create_buffer_check (ctx, ctx->staging_size , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4194
+ ctx->staging = ggml_vk_create_buffer_check (ctx, ctx->staging_size ,
4195
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4196
+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
4178
4197
}
4179
4198
}
4180
4199
0 commit comments