@@ -1926,7 +1926,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         if (device->fp16) {
             device_extensions.push_back("VK_KHR_shader_float16_int8");
         }
-        device->name = device->properties.deviceName.data();
+        device->name = GGML_VK_NAME + std::to_string(idx);
 
         device_create_info = {
             vk::DeviceCreateFlags(),
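Note on the hunk above: `device->name` now holds a stable registry identifier (GGML_VK_NAME plus the device index) instead of the driver-reported adapter string; the human-readable string stays available through the device-description path added further down. A minimal sketch of the resulting names, assuming GGML_VK_NAME expands to "Vulkan" as elsewhere in this codebase:

    // Illustration only, not part of the patch.
    #include <cstdio>
    #include <string>

    #define GGML_VK_NAME "Vulkan"  // assumed definition

    int main() {
        for (int idx = 0; idx < 2; idx++) {
            std::string name = GGML_VK_NAME + std::to_string(idx);
            std::printf("%s\n", name.c_str());  // prints Vulkan0, Vulkan1
        }
    }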
@@ -1953,7 +1953,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->buffer_type = {
             /* .iface   = */ ggml_backend_vk_buffer_type_interface,
-            /* .device  = */ nullptr,
+            /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
             /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
         };
 
@@ -6363,7 +6363,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
             /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
             /* .is_host        = */ ggml_backend_cpu_buffer_type()->iface.is_host,
         },
-        /* .device  = */ nullptr,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
         /* .context = */ nullptr,
     };
 
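With the two changes above, both the per-device buffer type and the host buffer type point back at their owning registry device instead of nullptr, so backend-agnostic code can recover the device that owns a buffer type. A sketch of what that enables; `ggml_backend_dev_name()` is assumed from the new registry API and is not shown in this diff:

    // Sketch only: walk from a buffer type back to its device.
    #include <cstdio>

    static void print_buft_owner(ggml_backend_buffer_type_t buft) {
        ggml_backend_dev_t dev = buft->device;  // was nullptr before this patch
        if (dev != NULL) {
            std::printf("buffer type '%s' belongs to device '%s'\n",
                        buft->iface.get_name(buft),
                        ggml_backend_dev_name(dev));  // assumed accessor
        }
    }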
@@ -6566,9 +6566,135 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     UNUSED(backend);
 }
 
-static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
-    // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
+// TODO: enable async and synchronize
+static ggml_backend_i ggml_backend_vk_interface = {
+    /* .get_name                = */ ggml_backend_vk_name,
+    /* .free                    = */ ggml_backend_vk_free,
+    /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
+    /* .get_tensor_async        = */ NULL,  // ggml_backend_vk_get_tensor_async,
+    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
+    /* .synchronize             = */ NULL,  // ggml_backend_vk_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
+    /* .offload_op              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_vk_guid() {
+    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
+    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
+
+    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
+    ggml_vk_init(ctx, dev_num);
+
+    ggml_backend_t vk_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_vk_guid(),
+        /* .interface = */ ggml_backend_vk_interface,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
+        /* .context   = */ ctx,
+    };
+
+    return vk_backend;
+}
+
+bool ggml_backend_is_vk(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
+}
+
+int ggml_backend_vk_get_device_count() {
+    return ggml_vk_get_device_count();
+}
+
+void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
+    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+    int dev_idx = vk_instance.device_indices[device];
+    ggml_vk_get_device_description(dev_idx, description, description_size);
+}
+
+void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
+    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+
+    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+
+    vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+
+    for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
+        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+            *total = heap.size;
+            *free = heap.size;
+            break;
+        }
+    }
+}
+
+//////////////////////////
+
+struct ggml_backend_vk_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
+    ggml_backend_vk_get_device_memory(ctx->device, free, total);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ggml_backend_vk_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    UNUSED(dev);
+    return ggml_backend_vk_host_buffer_type();
+}
 
+static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
+    UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+}
+
+static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_vk_device_get_name(dev);
+    props->description = ggml_backend_vk_device_get_description(dev);
+    props->type        = ggml_backend_vk_device_get_type(dev);
+    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* async       */ false,
+        /* host_buffer */ true,
+        /* events      */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
+    UNUSED(params);
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ggml_backend_vk_init(ctx->device);
+}
+
+static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
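The additions above move the backend vtable, GUID, and public init/query helpers ahead of their new call sites, and introduce a per-device context (ggml_backend_vk_device_context) plus the device-level entry points. A hedged usage sketch; the device functions are file-static here and in practice are reached through the ggml_backend_device_i table defined in the next hunk:

    // Sketch, assuming the registry getter defined later in this diff.
    ggml_backend_reg_t reg = ggml_backend_vk_reg();
    ggml_backend_dev_t dev = ggml_backend_vk_reg_get_device(reg, 0);

    struct ggml_backend_dev_props props;
    ggml_backend_vk_device_get_props(dev, &props);
    std::printf("%s (%s): %zu of %zu bytes free\n",
                props.name, props.description, props.memory_free, props.memory_total);

    ggml_backend_t backend = ggml_backend_vk_device_init(dev, /*params =*/ NULL);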
@@ -6686,97 +6812,101 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
             return false;
     }
 
-    UNUSED(backend);
-}
-
-static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
-    UNUSED(backend);
+    UNUSED(dev);
 }
 
-static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
     if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
         return false;
     }
 
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
     ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
-    return buft_ctx->device == ctx->device;
-}
-
-// TODO: enable async and synchronize
-static ggml_backend_i ggml_backend_vk_interface = {
-    /* .get_name                = */ ggml_backend_vk_name,
-    /* .free                    = */ ggml_backend_vk_free,
-    /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
-    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
-    /* .get_tensor_async        = */ NULL,  // ggml_backend_vk_get_tensor_async,
-    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
-    /* .synchronize             = */ NULL,  // ggml_backend_vk_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
-    /* .supports_op             = */ ggml_backend_vk_supports_op,
-    /* .supports_buft           = */ ggml_backend_vk_supports_buft,
-    /* .offload_op              = */ ggml_backend_vk_offload_op,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
 
-static ggml_guid_t ggml_backend_vk_guid() {
-    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
-    return &guid;
+    return buft_ctx->device->idx == ctx->device;
 }
 
-ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
-    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
+static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    const int min_batch_size = 32;
 
-    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
-    ggml_vk_init(ctx, dev_num);
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
 
-    ggml_backend_t vk_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_vk_guid(),
-        /* .interface = */ ggml_backend_vk_interface,
-        /* .device    = */ nullptr,
-        /* .context   = */ ctx,
-    };
+    UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
+    /* .get_name             = */ ggml_backend_vk_device_get_name,
+    /* .get_description      = */ ggml_backend_vk_device_get_description,
+    /* .get_memory           = */ ggml_backend_vk_device_get_memory,
+    /* .get_type             = */ ggml_backend_vk_device_get_type,
+    /* .get_props            = */ ggml_backend_vk_device_get_props,
+    /* .init_backend         = */ ggml_backend_vk_device_init,
+    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
+    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_vk_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_vk_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
 
-    return vk_backend;
+static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
+    UNUSED(reg);
+    return GGML_VK_NAME;
 }
 
-bool ggml_backend_is_vk(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
+static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
+    UNUSED(reg);
+    return ggml_backend_vk_get_device_count();
 }
 
-int ggml_backend_vk_get_device_count() {
-    return ggml_vk_get_device_count();
-}
+static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+    static std::vector<ggml_backend_dev_t> devices;
 
-void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
-    ggml_vk_get_device_description(device, description, description_size);
-}
+    static bool initialized = false;
 
-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
-    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+    {
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            for (size_t i = 0; i < ggml_backend_vk_get_device_count(); i++) {
+                ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
+                char desc[256];
+                ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
+                ctx->device = i;
+                ctx->name = GGML_VK_NAME + std::to_string(i);
+                ctx->description = desc;
+                devices.push_back(new ggml_backend_device {
+                    /* .iface   = */ ggml_backend_vk_device_i,
+                    /* .reg     = */ reg,
+                    /* .context = */ ctx,
+                });
+            }
+            initialized = true;
+        }
+    }
 
-    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+    GGML_ASSERT(device < devices.size());
+    return devices[device];
+}
 
-    vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
+    /* .get_name         = */ ggml_backend_vk_reg_get_name,
+    /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_vk_reg_get_device,
+    /* .get_proc_address = */ NULL,
+};
 
-    for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
-        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
-            *total = heap.size;
-            *free = heap.size;
-            break;
-        }
-    }
+ggml_backend_reg_t ggml_backend_vk_reg() {
+    static ggml_backend_reg reg = {
+        /* .iface   = */ ggml_backend_vk_reg_i,
+        /* .context = */ nullptr,
+    };
+
+    return &reg;
 }
 
 // Extension availability
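Taken together, the registration path is: ggml_backend_vk_reg() returns a singleton registry, whose get_device callback lazily builds one ggml_backend_device per Vulkan device under a mutex, and the backend and buffer-type structs now point back into that registry. Note that ggml_backend_vk_get_device_memory() reports free == total, since core Vulkan (without extensions such as VK_EXT_memory_budget) offers no free-VRAM query for a heap. An end-to-end sketch using only entry points from this diff plus ggml_backend_free() from the existing public API:

    // Enumerate Vulkan devices and create a backend for each one.
    #include <cstdio>

    int main() {
        for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
            char desc[256];
            ggml_backend_vk_get_device_description(i, desc, sizeof(desc));

            size_t free_mem = 0, total_mem = 0;
            ggml_backend_vk_get_device_memory(i, &free_mem, &total_mem);
            std::printf("device %d: %s (%zu MiB VRAM)\n", i, desc, total_mem / (1024 * 1024));

            ggml_backend_t backend = ggml_backend_vk_init(i);
            // ... build and compute graphs here ...
            ggml_backend_free(backend);  // from the existing ggml-backend API
        }
    }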