@@ -6779,6 +6779,155 @@ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total
6779
6779
}
6780
6780
}
6781
6781
6782
+ // ////////////////////////
6783
+
6784
// Per-device state hung off a Vulkan ggml_backend_device through its `context` pointer.
struct ggml_backend_vk_device_context {
    // Device index forwarded to the per-device ggml_backend_vk_* functions.
    int device;
    // Storage behind the C string handed out by ggml_backend_vk_device_get_name.
    std::string name;
    // Storage behind the C string handed out by ggml_backend_vk_device_get_description.
    std::string description;
};
6789
+
6790
+ static const char * ggml_backend_vk_device_get_name (ggml_backend_dev_t dev) {
6791
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6792
+ return ctx->name .c_str ();
6793
+ }
6794
+
6795
+ static const char * ggml_backend_vk_device_get_description (ggml_backend_dev_t dev) {
6796
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6797
+ return ctx->description .c_str ();
6798
+ }
6799
+
6800
+ static void ggml_backend_vk_device_get_memory (ggml_backend_dev_t device, size_t * free, size_t * total) {
6801
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context ;
6802
+ ggml_backend_vk_get_device_memory (ctx->device , free, total);
6803
+ }
6804
+
6805
+ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type (ggml_backend_dev_t dev) {
6806
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6807
+ return ggml_backend_vk_buffer_type (ctx->device );
6808
+ }
6809
+
6810
+ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type (ggml_backend_dev_t dev) {
6811
+ UNUSED (dev);
6812
+ return ggml_backend_vk_host_buffer_type ();
6813
+ }
6814
+
6815
+ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type (ggml_backend_dev_t dev) {
6816
+ UNUSED (dev);
6817
+ return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
6818
+ }
6819
+
6820
+ static void ggml_backend_vk_device_get_props (ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
6821
+ props->name = ggml_backend_vk_device_get_name (dev);
6822
+ props->description = ggml_backend_vk_device_get_description (dev);
6823
+ props->type = ggml_backend_vk_device_get_type (dev);
6824
+ ggml_backend_vk_device_get_memory (dev, &props->memory_free , &props->memory_total );
6825
+ props->caps = {
6826
+ /* async */ false ,
6827
+ /* host_buffer */ true ,
6828
+ /* events */ false ,
6829
+ };
6830
+ }
6831
+
6832
+ static ggml_backend_t ggml_backend_vk_device_init (ggml_backend_dev_t dev, const char * params) {
6833
+ UNUSED (params);
6834
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6835
+ return ggml_backend_vk_init (ctx->device );
6836
+ }
6837
+
6838
+ static bool ggml_backend_vk_device_supports_op (ggml_backend_dev_t dev, const ggml_tensor * op) {
6839
+ // TODO: move here
6840
+ UNUSED (dev);
6841
+ return ggml_backend_vk_supports_op (nullptr , op);
6842
+ }
6843
+
6844
+ static bool ggml_backend_vk_device_supports_buft (ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
6845
+ // TODO: move here
6846
+ UNUSED (dev);
6847
+ return ggml_backend_vk_supports_buft (nullptr , buft);
6848
+ }
6849
+
6850
+ static bool ggml_backend_vk_device_offload_op (ggml_backend_dev_t dev, const ggml_tensor * op) {
6851
+ // TODO: move here
6852
+ UNUSED (dev);
6853
+ return ggml_backend_vk_offload_op (nullptr , op);
6854
+ }
6855
+
6856
+ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
6857
+ /* .get_name = */ ggml_backend_vk_device_get_name,
6858
+ /* .get_description = */ ggml_backend_vk_device_get_description,
6859
+ /* .get_memory = */ ggml_backend_vk_device_get_memory,
6860
+ /* .get_type = */ ggml_backend_vk_device_get_type,
6861
+ /* .get_props = */ ggml_backend_vk_device_get_props,
6862
+ /* .init_backend = */ ggml_backend_vk_device_init,
6863
+ /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
6864
+ /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
6865
+ /* .buffer_from_host_ptr = */ NULL ,
6866
+ /* .supports_op = */ ggml_backend_vk_device_supports_op,
6867
+ /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
6868
+ /* .offload_op = */ ggml_backend_vk_device_offload_op,
6869
+ /* .event_new = */ NULL ,
6870
+ /* .event_free = */ NULL ,
6871
+ /* .event_synchronize = */ NULL ,
6872
+ };
6873
+
6874
+ static const char * ggml_backend_vk_reg_get_name (ggml_backend_reg_t reg) {
6875
+ UNUSED (reg);
6876
+ return GGML_VK_NAME;
6877
+ }
6878
+
6879
+ static size_t ggml_backend_vk_reg_get_device_count (ggml_backend_reg_t reg) {
6880
+ UNUSED (reg);
6881
+ return ggml_backend_vk_get_device_count ();
6882
+ }
6883
+
6884
+ static ggml_backend_dev_t ggml_backend_vk_reg_get_device (ggml_backend_reg_t reg, size_t device) {
6885
+ static std::vector<ggml_backend_dev_t > devices;
6886
+
6887
+ static bool initialized = false ;
6888
+
6889
+ {
6890
+ static std::mutex mutex;
6891
+ std::lock_guard<std::mutex> lock (mutex);
6892
+ if (!initialized) {
6893
+ for (size_t i = 0 ; i < ggml_backend_vk_get_device_count (); i++) {
6894
+ ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
6895
+ ctx->device = i;
6896
+ char desc[256 ];
6897
+ ggml_backend_vk_get_device_description (i, desc, sizeof (desc));
6898
+ ctx->name = GGML_VK_NAME + std::to_string (i);
6899
+ ctx->description = desc;
6900
+ devices.push_back (new ggml_backend_device {
6901
+ /* .iface = */ ggml_backend_vk_device_i,
6902
+ /* .reg = */ reg,
6903
+ /* .context = */ ctx,
6904
+ });
6905
+ }
6906
+ initialized = true ;
6907
+ }
6908
+ }
6909
+
6910
+ GGML_ASSERT (device < devices.size ());
6911
+ return devices[device];
6912
+ }
6913
+
6914
+ static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
6915
+ /* .get_name = */ ggml_backend_vk_reg_get_name,
6916
+ /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
6917
+ /* .get_device = */ ggml_backend_vk_reg_get_device,
6918
+ /* .get_proc_address = */ NULL ,
6919
+ /* .set_log_callback = */ NULL ,
6920
+ };
6921
+
6922
+ ggml_backend_reg_t ggml_backend_vk_reg () {
6923
+ static ggml_backend_reg reg = {
6924
+ /* .iface = */ ggml_backend_vk_reg_i,
6925
+ /* .context = */ nullptr ,
6926
+ };
6927
+
6928
+ return ®
6929
+ }
6930
+
6782
6931
// Extension availability
6783
6932
static bool ggml_vk_instance_validation_ext_available (const std::vector<vk::ExtensionProperties>& instance_extensions) {
6784
6933
#ifdef GGML_VULKAN_VALIDATE
0 commit comments