Commit aa303fe

kompute: add backend registry / device interfaces
Get in line with the other backends by supporting the newer backend/device registry interfaces.

Signed-off-by: Sergio Lopez <[email protected]>
1 parent 2f8bd2b commit aa303fe
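
For context, the registry entry point this commit adds is meant to be consumed through the generic ggml-backend API rather than the backend-specific init call. Below is a minimal sketch (not part of the commit) of enumerating and initializing the Kompute devices, assuming the registry helpers declared in ggml-backend.h (ggml_backend_reg_dev_count, ggml_backend_reg_dev_get, ggml_backend_dev_name, ggml_backend_dev_init, ggml_backend_free):

// Sketch only, not part of this commit; assumes the registry helpers from ggml-backend.h.
#include <cstdio>

#include "ggml-backend.h"
#include "ggml-kompute.h"

int main() {
    ggml_backend_reg_t reg = ggml_backend_kompute_reg();

    // Walk the devices exposed by the Kompute registry.
    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        printf("device %zu: %s\n", i, ggml_backend_dev_name(dev));
    }

    // Initializing through the device interface mirrors ggml_backend_kompute_init(0).
    if (ggml_backend_reg_dev_count(reg) > 0) {
        ggml_backend_t backend = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), nullptr);
        ggml_backend_free(backend);
    }

    return 0;
}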

4 files changed: +219 -69 lines changed


ggml/include/ggml-kompute.h

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
 
 GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
 
+GGML_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
+
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-backend.cpp

Lines changed: 7 additions & 2 deletions
@@ -565,6 +565,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
 #include "ggml-cann.h"
 #endif
 
+#ifdef GGML_USE_KOMPUTE
+#include "ggml-kompute.h"
+#endif
+
 struct ggml_backend_registry {
     std::vector<ggml_backend_reg_t> backends;
     std::vector<ggml_backend_dev_t> devices;
@@ -594,8 +598,9 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_CANN
         register_backend(ggml_backend_cann_reg());
 #endif
-
-        // TODO: kompute
+#ifdef GGML_USE_KOMPUTE
+        register_backend(ggml_backend_kompute_reg());
+#endif
 
         register_backend(ggml_backend_cpu_reg());
     }
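
With this registration in place, a build configured with GGML_USE_KOMPUTE contributes its devices to the global enumeration alongside the other backends. A rough illustration (not part of the commit), assuming the global helpers ggml_backend_dev_count, ggml_backend_dev_get, ggml_backend_dev_name, ggml_backend_dev_description and ggml_backend_dev_memory from ggml-backend.h:

// Sketch only; lists every registered device, Kompute included when GGML_USE_KOMPUTE is defined.
#include <cstdio>

#include "ggml-backend.h"

int main() {
    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        size_t free = 0, total = 0;
        ggml_backend_dev_memory(dev, &free, &total);
        printf("%s: %s (%zu MiB free of %zu MiB)\n",
               ggml_backend_dev_name(dev),
               ggml_backend_dev_description(dev),
               free / (1024 * 1024), total / (1024 * 1024));
    }
    return 0;
}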

ggml/src/ggml-kompute.cpp

Lines changed: 210 additions & 25 deletions
@@ -42,6 +42,7 @@
 #include <cstring>
 #include <iostream>
 #include <memory>
+#include <mutex>
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
@@ -1323,17 +1324,7 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
     ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
 }
 
-static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
-    switch (op->type) {
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            break;
-        default:
-            return false;
-    }
-
+static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -1410,6 +1401,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
             ;
     }
     return false;
+
+    GGML_UNUSED(dev);
 }
 
 static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
@@ -1458,10 +1451,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 
             any_commands_recorded = true;
 
+            /* Do we still need this?
             if (!ggml_vk_supports_op(dst)) {
                 fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
                 GGML_ABORT("unsupported op");
             }
+            */
 
             const int32_t ne00 = src0 ? src0->ne[0] : 0;
             const int32_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1921,7 +1916,7 @@ ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
         for (const auto & dev : devices) {
             vec.push_back({
                 /* .iface   = */ ggml_backend_kompute_buffer_type_interface,
-                /* .device  = */ nullptr,
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), 0),
                 /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc)
             });
         }
@@ -1964,16 +1959,6 @@ static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, st
     return GGML_STATUS_SUCCESS;
 }
 
-static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    GGML_UNUSED(backend);
-    return ggml_vk_supports_op(op);
-}
-
-static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(backend);
-    return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
-}
-
 static struct ggml_backend_i kompute_backend_i = {
     /* .get_name                = */ ggml_backend_kompute_name,
     /* .free                    = */ ggml_backend_kompute_free,
@@ -1987,8 +1972,8 @@ static struct ggml_backend_i kompute_backend_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_kompute_graph_compute,
-    /* .supports_op             = */ ggml_backend_kompute_supports_op,
-    /* .supports_buft           = */ ggml_backend_kompute_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -2006,7 +1991,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
     ggml_backend_t kompute_backend = new ggml_backend {
         /* .guid      = */ ggml_backend_kompute_guid(),
         /* .interface = */ kompute_backend_i,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), 0),
        /* .context   = */ s_kompute_context,
     };
 
@@ -2016,3 +2001,203 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
 bool ggml_backend_is_kompute(ggml_backend_t backend) {
     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
 }
+
+int ggml_backend_kompute_get_device_count() {
+    auto devices = ggml_vk_available_devices_internal(0);
+    return devices.size();
+}
+
+void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
+    std::vector<vk::PhysicalDevice> physical_devices;
+    try {
+        physical_devices = komputeManager()->listDevices();
+    } catch (vk::SystemError & err) {
+        std::cerr << __func__ << ": Vulkan exception: " << err.what() << "\n";
+        GGML_ABORT("");
+    }
+
+    GGML_ASSERT(device < physical_devices.size());
+
+    const auto & physical_device = physical_devices[device];
+    VkPhysicalDeviceProperties dev_props = physical_device.getProperties();
+
+    auto devices = ggml_vk_available_devices_internal(0);
+    snprintf(description, description_size, "%s", dev_props.deviceName);
+}
+
+void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
+    std::vector<vk::PhysicalDevice> physical_devices;
+    try {
+        physical_devices = komputeManager()->listDevices();
+    } catch (vk::SystemError & err) {
+        std::cerr << __func__ << ": Vulkan exception: " << err.what() << "\n";
+        GGML_ABORT("");
+    }
+
+    GGML_ASSERT(device < physical_devices.size());
+
+    const auto & physical_device = physical_devices[device];
+
+    vk::PhysicalDeviceMemoryProperties memprops = physical_device.getMemoryProperties();
+
+    for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
+        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+            *total = heap.size;
+            *free = heap.size;
+            break;
+        }
+    }
+}
+
+//////////////////////////
+
+struct ggml_backend_kompute_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    ggml_backend_kompute_get_device_memory(ctx->device, free, total);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    return ggml_backend_kompute_buffer_type(ctx->device);
+}
+
+static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
+        return false;
+    }
+
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
+
+    return buft_ctx->device == ctx->device;
+}
+
+//TODO
+/**
+static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    GGML_ABORT("Unimplemented");
+    return ggml_backend_kompute_host_buffer_type();
+}
+*/
+
+static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+}
+
+static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_kompute_device_get_name(dev);
+    props->description = ggml_backend_kompute_device_get_description(dev);
+    props->type        = ggml_backend_kompute_device_get_type(dev);
+    ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* async       */ false,
+        /* host_buffer */ false,
+        /* events      */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
+    GGML_UNUSED(params);
+    ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+    return ggml_backend_kompute_init(ctx->device);
+}
+
+static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    const int min_batch_size = 32;
+
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
+    /* .get_name             = */ ggml_backend_kompute_device_get_name,
+    /* .get_description      = */ ggml_backend_kompute_device_get_description,
+    /* .get_memory           = */ ggml_backend_kompute_device_get_memory,
+    /* .get_type             = */ ggml_backend_kompute_device_get_type,
+    /* .get_props            = */ ggml_backend_kompute_device_get_props,
+    /* .init_backend         = */ ggml_backend_kompute_device_init,
+    /* .get_buffer_type      = */ ggml_backend_kompute_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_kompute_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_kompute_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_kompute_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return "Kompute";
+}
+
+static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return ggml_backend_kompute_get_device_count();
+}
+
+static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+    static std::vector<ggml_backend_dev_t> devices;
+
+    static bool initialized = false;
+
+    {
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
+                ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
+                char desc[256];
+                ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
+                ctx->device = i;
+                ctx->name = "Kompute" + std::to_string(i);
+                ctx->description = desc;
+                devices.push_back(new ggml_backend_device {
+                    /* .iface   = */ ggml_backend_kompute_device_i,
+                    /* .reg     = */ reg,
+                    /* .context = */ ctx,
+                });
+            }
+            initialized = true;
+        }
+    }
+
+    GGML_ASSERT(device < devices.size());
+    return devices[device];
+}
+
+static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
+    /* .get_name         = */ ggml_backend_kompute_reg_get_name,
+    /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_kompute_reg_get_device,
+    /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_kompute_reg() {
+    static ggml_backend_reg reg = {
+        /* .iface   = */ ggml_backend_kompute_reg_i,
+        /* .context = */ nullptr,
+    };
+
+    return &reg;
+}
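
The device interface above is what generic code paths call into; get_props, for example, fans out to the get_name, get_description, get_type and get_memory callbacks registered in ggml_backend_kompute_device_i. A small sketch of exercising it (not part of the commit; ggml_backend_dev_get_props, ggml_backend_dev_buffer_type and ggml_backend_buft_name are assumed to be the corresponding wrappers in ggml-backend.h):

// Sketch only; queries the per-device properties and buffer type exposed via the Kompute registry.
#include <cstdio>

#include "ggml-backend.h"
#include "ggml-kompute.h"

int main() {
    ggml_backend_reg_t reg = ggml_backend_kompute_reg();

    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);

        struct ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);
        printf("%s (%s): %zu bytes free of %zu\n",
               props.name, props.description, props.memory_free, props.memory_total);

        // Buffers of this type are the ones accepted by the device's supports_buft callback.
        ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev);
        printf("  buffer type: %s\n", ggml_backend_buft_name(buft));
    }
    return 0;
}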
