42
42
#include < cstring>
43
43
#include < iostream>
44
44
#include < memory>
45
+ #include < mutex>
45
46
#include < stdexcept>
46
47
#include < string>
47
48
#include < unordered_map>
@@ -1323,17 +1324,7 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
1323
1324
ggml_vk_cpy (spirv, 2 , 4 , std::forward<Args>(args)...);
1324
1325
}
1325
1326
1326
- static bool ggml_vk_supports_op (const struct ggml_tensor * op) {
1327
- switch (op->type ) {
1328
- case GGML_TYPE_F16:
1329
- case GGML_TYPE_F32:
1330
- case GGML_TYPE_Q4_0:
1331
- case GGML_TYPE_Q4_1:
1332
- break ;
1333
- default :
1334
- return false ;
1335
- }
1336
-
1327
+ static bool ggml_backend_kompute_device_supports_op (ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1337
1328
switch (op->op ) {
1338
1329
case GGML_OP_UNARY:
1339
1330
switch (ggml_get_unary_op (op)) {
@@ -1410,6 +1401,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
1410
1401
;
1411
1402
}
1412
1403
return false ;
1404
+
1405
+ GGML_UNUSED (dev);
1413
1406
}
1414
1407
1415
1408
static void ggml_vk_graph_compute (struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
@@ -1458,10 +1451,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1458
1451
1459
1452
any_commands_recorded = true ;
1460
1453
1454
+ /* Do we still need this?
1461
1455
if (!ggml_vk_supports_op(dst)) {
1462
1456
fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
1463
1457
GGML_ABORT("unsupported op");
1464
1458
}
1459
+ */
1465
1460
1466
1461
const int32_t ne00 = src0 ? src0->ne [0 ] : 0 ;
1467
1462
const int32_t ne01 = src0 ? src0->ne [1 ] : 0 ;
@@ -1921,7 +1916,7 @@ ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
1921
1916
for (const auto & dev : devices) {
1922
1917
vec.push_back ({
1923
1918
/* .iface = */ ggml_backend_kompute_buffer_type_interface,
1924
- /* .device = */ nullptr ,
1919
+ /* .device = */ ggml_backend_reg_dev_get ( ggml_backend_kompute_reg (), 0 ) ,
1925
1920
/* .context = */ new ggml_backend_kompute_buffer_type_context (dev.index , dev.bufferAlignment , dev.maxAlloc )
1926
1921
});
1927
1922
}
@@ -1964,16 +1959,6 @@ static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, st
1964
1959
return GGML_STATUS_SUCCESS;
1965
1960
}
1966
1961
1967
- static bool ggml_backend_kompute_supports_op (ggml_backend_t backend, const struct ggml_tensor * op) {
1968
- GGML_UNUSED (backend);
1969
- return ggml_vk_supports_op (op);
1970
- }
1971
-
1972
- static bool ggml_backend_kompute_supports_buft (ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
1973
- GGML_UNUSED (backend);
1974
- return buft->iface .get_name == ggml_backend_kompute_buffer_type_get_name;
1975
- }
1976
-
1977
1962
static struct ggml_backend_i kompute_backend_i = {
1978
1963
/* .get_name = */ ggml_backend_kompute_name,
1979
1964
/* .free = */ ggml_backend_kompute_free,
@@ -1987,8 +1972,8 @@ static struct ggml_backend_i kompute_backend_i = {
1987
1972
/* .graph_plan_update = */ NULL ,
1988
1973
/* .graph_plan_compute = */ NULL ,
1989
1974
/* .graph_compute = */ ggml_backend_kompute_graph_compute,
1990
- /* .supports_op = */ ggml_backend_kompute_supports_op ,
1991
- /* .supports_buft = */ ggml_backend_kompute_supports_buft ,
1975
+ /* .supports_op = */ NULL ,
1976
+ /* .supports_buft = */ NULL ,
1992
1977
/* .offload_op = */ NULL ,
1993
1978
/* .event_record = */ NULL ,
1994
1979
/* .event_wait = */ NULL ,
@@ -2006,7 +1991,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
2006
1991
ggml_backend_t kompute_backend = new ggml_backend {
2007
1992
/* .guid = */ ggml_backend_kompute_guid (),
2008
1993
/* .interface = */ kompute_backend_i,
2009
- /* .device = */ nullptr ,
1994
+ /* .device = */ ggml_backend_reg_dev_get ( ggml_backend_kompute_reg (), 0 ) ,
2010
1995
/* .context = */ s_kompute_context,
2011
1996
};
2012
1997
@@ -2016,3 +2001,203 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
2016
2001
// Returns true iff `backend` is a non-null backend whose GUID matches the
// Kompute backend GUID.
bool ggml_backend_is_kompute(ggml_backend_t backend) {
    if (backend == NULL) {
        return false;
    }
    return ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
}
2004
+
2005
+ int ggml_backend_kompute_get_device_count () {
2006
+ auto devices = ggml_vk_available_devices_internal (0 );
2007
+ return devices.size ();
2008
+ }
2009
+
2010
+ void ggml_backend_kompute_get_device_description (int device, char * description, size_t description_size) {
2011
+ std::vector<vk::PhysicalDevice> physical_devices;
2012
+ try {
2013
+ physical_devices = komputeManager ()->listDevices ();
2014
+ } catch (vk::SystemError & err) {
2015
+ std::cerr << __func__ << " : Vulkan exception: " << err.what () << " \n " ;
2016
+ GGML_ABORT (" " );
2017
+ }
2018
+
2019
+ GGML_ASSERT (device < physical_devices.size ());
2020
+
2021
+ const auto & physical_device = physical_devices[device];
2022
+ VkPhysicalDeviceProperties dev_props = physical_device.getProperties ();
2023
+
2024
+ auto devices = ggml_vk_available_devices_internal (0 );
2025
+ snprintf (description, description_size, " %s" , dev_props.deviceName );
2026
+ }
2027
+
2028
+ void ggml_backend_kompute_get_device_memory (int device, size_t * free, size_t * total) {
2029
+ std::vector<vk::PhysicalDevice> physical_devices;
2030
+ try {
2031
+ physical_devices = komputeManager ()->listDevices ();
2032
+ } catch (vk::SystemError & err) {
2033
+ std::cerr << __func__ << " : Vulkan exception: " << err.what () << " \n " ;
2034
+ GGML_ABORT (" " );
2035
+ }
2036
+
2037
+ GGML_ASSERT (device < physical_devices.size ());
2038
+
2039
+ const auto & physical_device = physical_devices[device];
2040
+
2041
+ vk::PhysicalDeviceMemoryProperties memprops = physical_device.getMemoryProperties ();
2042
+
2043
+ for (const vk::MemoryHeap& heap : memprops.memoryHeaps ) {
2044
+ if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
2045
+ *total = heap.size ;
2046
+ *free = heap.size ;
2047
+ break ;
2048
+ }
2049
+ }
2050
+ }
2051
+
2052
+ // ////////////////////////
2053
+
2054
// Per-device state attached to each ggml_backend_device this backend registers.
struct ggml_backend_kompute_device_context {
    int device;              // Vulkan device index (position in komputeManager()->listDevices())
    std::string name;        // backend-assigned name, "Kompute" + device index
    std::string description; // device description string reported by Vulkan
};
2059
+
2060
+ static const char * ggml_backend_kompute_device_get_name (ggml_backend_dev_t dev) {
2061
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context ;
2062
+ return ctx->name .c_str ();
2063
+ }
2064
+
2065
+ static const char * ggml_backend_kompute_device_get_description (ggml_backend_dev_t dev) {
2066
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context ;
2067
+ return ctx->description .c_str ();
2068
+ }
2069
+
2070
+ static void ggml_backend_kompute_device_get_memory (ggml_backend_dev_t dev, size_t * free, size_t * total) {
2071
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context ;
2072
+ ggml_backend_kompute_get_device_memory (ctx->device , free, total);
2073
+ }
2074
+
2075
+ static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type (ggml_backend_dev_t dev) {
2076
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context ;
2077
+ return ggml_backend_kompute_buffer_type (ctx->device );
2078
+ }
2079
+
2080
+ static bool ggml_backend_kompute_device_supports_buft (ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2081
+ if (buft->iface .get_name != ggml_backend_kompute_buffer_type_get_name) {
2082
+ return false ;
2083
+ }
2084
+
2085
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context ;
2086
+ ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context ;
2087
+
2088
+ return buft_ctx->device == ctx->device ;
2089
+ }
2090
+
2091
// TODO: not implemented yet — a host (pinned) buffer type is not provided.
/**
static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    GGML_ABORT("Unimplemented");
    return ggml_backend_kompute_host_buffer_type();
}
*/
2098
+
2099
// Backend-device iface: every Kompute device is reported as a full GPU,
// regardless of which physical device it is.
static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);
    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
}
2103
+
2104
+ static void ggml_backend_kompute_device_get_props (ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
2105
+ props->name = ggml_backend_kompute_device_get_name (dev);
2106
+ props->description = ggml_backend_kompute_device_get_description (dev);
2107
+ props->type = ggml_backend_kompute_device_get_type (dev);
2108
+ ggml_backend_kompute_device_get_memory (dev, &props->memory_free , &props->memory_total );
2109
+ props->caps = {
2110
+ /* async */ false ,
2111
+ /* host_buffer */ false ,
2112
+ /* events */ false ,
2113
+ };
2114
+ }
2115
+
2116
+ static ggml_backend_t ggml_backend_kompute_device_init (ggml_backend_dev_t dev, const char * params) {
2117
+ GGML_UNUSED (params);
2118
+ ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context ;
2119
+ return ggml_backend_kompute_init (ctx->device );
2120
+ }
2121
+
2122
+ static bool ggml_backend_kompute_device_offload_op (ggml_backend_dev_t dev, const ggml_tensor * op) {
2123
+ const int min_batch_size = 32 ;
2124
+
2125
+ return (op->ne [1 ] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
2126
+ (op->ne [2 ] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
2127
+
2128
+ GGML_UNUSED (dev);
2129
+ }
2130
+
2131
// vtable wiring the Kompute device functions into the generic
// ggml_backend_device interface; NULL entries are unimplemented features
// (host buffers, host-pointer buffers, events).
static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
    /* .get_name             = */ ggml_backend_kompute_device_get_name,
    /* .get_description      = */ ggml_backend_kompute_device_get_description,
    /* .get_memory           = */ ggml_backend_kompute_device_get_memory,
    /* .get_type             = */ ggml_backend_kompute_device_get_type,
    /* .get_props            = */ ggml_backend_kompute_device_get_props,
    /* .init_backend         = */ ggml_backend_kompute_device_init,
    /* .get_buffer_type      = */ ggml_backend_kompute_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ NULL,
    /* .supports_op          = */ ggml_backend_kompute_device_supports_op,
    /* .supports_buft        = */ ggml_backend_kompute_device_supports_buft,
    /* .offload_op           = */ ggml_backend_kompute_device_offload_op,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};
2148
+
2149
+ static const char * ggml_backend_kompute_reg_get_name (ggml_backend_reg_t reg) {
2150
+ GGML_UNUSED (reg);
2151
+ return " Kompute" ;
2152
+ }
2153
+
2154
+ static size_t ggml_backend_kompute_reg_get_device_count (ggml_backend_reg_t reg) {
2155
+ GGML_UNUSED (reg);
2156
+ return ggml_backend_kompute_get_device_count ();
2157
+ }
2158
+
2159
+ static ggml_backend_dev_t ggml_backend_kompute_reg_get_device (ggml_backend_reg_t reg, size_t device) {
2160
+ static std::vector<ggml_backend_dev_t > devices;
2161
+
2162
+ static bool initialized = false ;
2163
+
2164
+ {
2165
+ static std::mutex mutex;
2166
+ std::lock_guard<std::mutex> lock (mutex);
2167
+ if (!initialized) {
2168
+ for (size_t i = 0 ; i < ggml_backend_kompute_get_device_count (); i++) {
2169
+ ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
2170
+ char desc[256 ];
2171
+ ggml_backend_kompute_get_device_description (i, desc, sizeof (desc));
2172
+ ctx->device = i;
2173
+ ctx->name = " Kompute" + std::to_string (i);
2174
+ ctx->description = desc;
2175
+ devices.push_back (new ggml_backend_device {
2176
+ /* .iface = */ ggml_backend_kompute_device_i,
2177
+ /* .reg = */ reg,
2178
+ /* .context = */ ctx,
2179
+ });
2180
+ }
2181
+ initialized = true ;
2182
+ }
2183
+ }
2184
+
2185
+ GGML_ASSERT (device < devices.size ());
2186
+ return devices[device];
2187
+ }
2188
+
2189
// vtable wiring the Kompute registry functions into the generic
// ggml_backend_reg interface; get_proc_address is not provided.
static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
    /* .get_name         = */ ggml_backend_kompute_reg_get_name,
    /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
    /* .get_device       = */ ggml_backend_kompute_reg_get_device,
    /* .get_proc_address = */ NULL,
};
2195
+
2196
+ ggml_backend_reg_t ggml_backend_kompute_reg () {
2197
+ static ggml_backend_reg reg = {
2198
+ /* .iface = */ ggml_backend_kompute_reg_i,
2199
+ /* .context = */ nullptr ,
2200
+ };
2201
+
2202
+ return ®
2203
+ }
0 commit comments