@@ -652,9 +652,21 @@ static void ggml_vk_queue_cleanup(vk_queue& q) {
652
652
q.cmd_buffer_idx = 0 ;
653
653
}
654
654
655
- static vk_buffer ggml_vk_create_buffer (size_t size, vk::MemoryPropertyFlags req_flags) {
655
+ static int32_t find_properties (const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
656
+ for (uint32_t i = 0 ; i < mem_props->memoryTypeCount ; ++i) {
657
+ vk::MemoryType memory_type = mem_props->memoryTypes [i];
658
+ if ((mem_req->memoryTypeBits & ((uint64_t )1 << i)) &&
659
+ (flags & memory_type.propertyFlags ) == flags &&
660
+ mem_props->memoryHeaps [memory_type.heapIndex ].size >= mem_req->size ) {
661
+ return static_cast <int32_t >(i);
662
+ }
663
+ }
664
+ return -1 ;
665
+ }
666
+
667
+ static vk_buffer ggml_vk_create_buffer (size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags desired_flags = vk::MemoryPropertyFlags(0 )) {
656
668
#ifdef GGML_VULKAN_DEBUG
657
- std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " )" << std::endl;
669
+ std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " , " << to_string (desired_flags) << " )" << std::endl;
658
670
#endif
659
671
GGML_ASSERT (size > 0 );
660
672
@@ -676,17 +688,15 @@ static vk_buffer ggml_vk_create_buffer(size_t size, vk::MemoryPropertyFlags req_
676
688
677
689
vk::PhysicalDeviceMemoryProperties mem_props = vk_device.physical_device .getMemoryProperties ();
678
690
679
- uint32_t memory_type_index = UINT32_MAX;
680
-
681
- for (uint32_t i = 0 ; i < mem_props.memoryTypeCount ; ++i) {
682
- vk::MemoryType memory_type = mem_props.memoryTypes [i];
683
- if ((mem_req.memoryTypeBits & ((uint64_t )1 << i)) && (req_flags & memory_type.propertyFlags ) == req_flags && mem_props.memoryHeaps [memory_type.heapIndex ].size >= mem_req.size ) {
684
- memory_type_index = i;
685
- break ;
686
- }
691
+ uint32_t memory_type_index = -1 ;
692
+ if (desired_flags) {
693
+ memory_type_index = find_properties (&mem_props, &mem_req, req_flags | desired_flags);
694
+ }
695
+ if (memory_type_index == -1 ) {
696
+ memory_type_index = find_properties (&mem_props, &mem_req, req_flags);
687
697
}
688
698
689
- if (memory_type_index >= mem_props. memoryTypeCount ) {
699
+ if (memory_type_index == - 1 ) {
690
700
throw vk::OutOfDeviceMemoryError (" No suitable memory type found" );
691
701
}
692
702
@@ -712,9 +722,9 @@ static vk_buffer ggml_vk_create_buffer(size_t size, vk::MemoryPropertyFlags req_
712
722
return buf;
713
723
}
714
724
715
- static vk_buffer ggml_vk_create_buffer_check (size_t size, vk::MemoryPropertyFlags req_flags) {
725
+ static vk_buffer ggml_vk_create_buffer_check (size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags desired_flags = vk::MemoryPropertyFlags( 0 ) ) {
716
726
try {
717
- return ggml_vk_create_buffer (size, req_flags);
727
+ return ggml_vk_create_buffer (size, req_flags, desired_flags );
718
728
} catch (const vk::SystemError& e) {
719
729
std::cerr << " ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
720
730
std::cerr << " ggml_vulkan: " << e.what () << std::endl;
@@ -729,7 +739,10 @@ static vk_buffer ggml_vk_create_buffer_device(size_t size) {
729
739
} catch (const vk::SystemError& e) {
730
740
if (vk_device.uma ) {
731
741
// Fall back to host memory type
732
- buf = ggml_vk_create_buffer_check (size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
742
+ buf = ggml_vk_create_buffer_check (
743
+ size,
744
+ /* required */ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
745
+ /* optional */ vk::MemoryPropertyFlagBits::eHostCached);
733
746
} else {
734
747
std::cerr << " ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
735
748
std::cerr << " ggml_vulkan: " << e.what () << std::endl;
@@ -1261,7 +1274,10 @@ static void * ggml_vk_host_malloc(size_t size) {
1261
1274
#ifdef GGML_VULKAN_DEBUG
1262
1275
std::cerr << " ggml_vk_host_malloc(" << size << " )" << std::endl;
1263
1276
#endif
1264
- vk_buffer buf = ggml_vk_create_buffer (size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1277
+ vk_buffer buf = ggml_vk_create_buffer (
1278
+ size,
1279
+ /* required */ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
1280
+ /* optional */ vk::MemoryPropertyFlagBits::eHostCached);
1265
1281
1266
1282
if (!(buf.memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
1267
1283
fprintf (stderr, " WARNING: failed to allocate %.2f MB of pinned memory\n " ,
@@ -1408,7 +1424,10 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
1408
1424
static void ensure_sync_staging_buffer (size_t size) {
1409
1425
if (vk_sync_staging.size < size) {
1410
1426
ggml_vk_destroy_buffer (vk_sync_staging);
1411
- vk_sync_staging = ggml_vk_create_buffer_check (size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1427
+ vk_sync_staging = ggml_vk_create_buffer_check (
1428
+ size,
1429
+ /* required */ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
1430
+ /* optional */ vk::MemoryPropertyFlagBits::eHostCached);
1412
1431
}
1413
1432
}
1414
1433
@@ -3812,7 +3831,10 @@ void ggml_vk_preallocate_buffers() {
3812
3831
std::cerr << " qx_size: " << vk_prealloc_size_qx << " qy_size: " << vk_prealloc_size_qy << " x_size: " << vk_prealloc_size_x << " y_size: " << vk_prealloc_size_y << " split_k_size: " << vk_prealloc_size_split_k << std::endl;
3813
3832
#endif
3814
3833
#if defined(GGML_VULKAN_RUN_TESTS)
3815
- vk_staging = ggml_vk_create_buffer_check (100ul * 1024ul * 1024ul , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
3834
+ vk_staging = ggml_vk_create_buffer_check (
3835
+ 100ul * 1024ul * 1024ul ,
3836
+ /* required */ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
3837
+ /* optional */ vk::MemoryPropertyFlagBits::eHostCached);
3816
3838
ggml_vk_test_transfer (8192 * 1000 , false );
3817
3839
ggml_vk_test_transfer (8192 * 1000 , true );
3818
3840
@@ -3904,7 +3926,10 @@ void ggml_vk_preallocate_buffers() {
3904
3926
if (vk_staging.size > 0 ) {
3905
3927
ggml_vk_destroy_buffer (vk_staging);
3906
3928
}
3907
- vk_staging = ggml_vk_create_buffer_check (vk_staging_size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
3929
+ vk_staging = ggml_vk_create_buffer_check (
3930
+ vk_staging_size,
3931
+ /* required */ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
3932
+ /* optional */ vk::MemoryPropertyFlagBits::eHostCached);
3908
3933
}
3909
3934
}
3910
3935
0 commit comments