vulkan : add backend registry / device interfaces #9721

Merged: 3 commits, Oct 17, 2024
ggml/include/ggml-vulkan.h: 2 additions & 0 deletions

@@ -24,6 +24,8 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);

GGML_API ggml_backend_reg_t ggml_backend_vk_reg(void);

#ifdef __cplusplus
}
#endif
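
With this new export, callers can reach the Vulkan backend through the generic registry interface instead of backend-specific entry points. A minimal usage sketch (not part of this PR), assuming the registry and device helpers declared in ggml-backend.h:

```cpp
#include <cstdio>
#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main() {
    ggml_backend_reg_t reg = ggml_backend_vk_reg();
    // Walk every Vulkan device exposed by the registry and print its identity.
    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        printf("%s: %s\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
    }
    return 0;
}
```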
ggml/src/ggml-backend.cpp: 8 additions & 1 deletion

@@ -538,6 +538,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
#include "ggml-metal.h"
#endif

#ifdef GGML_USE_VULKAN
#include "ggml-vulkan.h"
#endif

#ifdef GGML_USE_BLAS
#include "ggml-blas.h"
#endif
@@ -557,14 +561,17 @@ struct ggml_backend_registry {
#ifdef GGML_USE_METAL
register_backend(ggml_backend_metal_reg());
#endif
#ifdef GGML_USE_VULKAN
register_backend(ggml_backend_vk_reg());
#endif
#ifdef GGML_USE_BLAS
register_backend(ggml_backend_blas_reg());
#endif
#ifdef GGML_USE_RPC
register_backend(ggml_backend_rpc_reg());
#endif

-// TODO: sycl, vulkan, kompute, cann
+// TODO: sycl, kompute, cann

register_backend(ggml_backend_cpu_reg());
}
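
Once registered here, Vulkan devices also become visible through the global device enumeration, so device selection no longer needs Vulkan-specific calls. A hedged sketch, assuming the ggml_backend_dev_* helpers introduced by the registry refactor:

```cpp
#include <cstdio>
#include "ggml-backend.h"

int main() {
    // Pick the first full-GPU device the global registry knows about
    // (Vulkan devices report GGML_BACKEND_DEVICE_TYPE_GPU_FULL below).
    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU_FULL);
    if (dev == NULL) {
        printf("no GPU device registered\n");
        return 1;
    }
    ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
    printf("initialized %s\n", ggml_backend_name(backend));
    ggml_backend_free(backend);
    return 0;
}
```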
ggml/src/ggml-vulkan.cpp: 204 additions & 74 deletions

@@ -1941,7 +1941,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
if (device->fp16) {
device_extensions.push_back("VK_KHR_shader_float16_int8");
}
-device->name = device->properties.deviceName.data();
+device->name = GGML_VK_NAME + std::to_string(idx);

device_create_info = {
vk::DeviceCreateFlags(),
@@ -1968,7 +1968,7 @@ static vk_device ggml_vk_get_device(size_t idx) {

device->buffer_type = {
/* .iface = */ ggml_backend_vk_buffer_type_interface,
-/* .device = */ nullptr,
+/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
/* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
};

@@ -6378,7 +6378,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
},
-/* .device = */ nullptr,
+/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
/* .context = */ nullptr,
};

@@ -6581,9 +6581,135 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
UNUSED(backend);
}

-static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
-// ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
+// TODO: enable async and synchronize
static ggml_backend_i ggml_backend_vk_interface = {
/* .get_name = */ ggml_backend_vk_name,
/* .free = */ ggml_backend_vk_free,
/* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
/* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
/* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
/* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
/* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_vk_graph_compute,
/* .supports_op = */ NULL,
/* .supports_buft = */ NULL,
/* .offload_op = */ NULL,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
};

static ggml_guid_t ggml_backend_vk_guid() {
static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
return &guid;
}

ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");

ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
ggml_vk_init(ctx, dev_num);

ggml_backend_t vk_backend = new ggml_backend {
/* .guid = */ ggml_backend_vk_guid(),
/* .interface = */ ggml_backend_vk_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
/* .context = */ ctx,
};

return vk_backend;
}
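
For reference, a hypothetical direct-initialization sketch (not from the diff) using the function above together with the GUID check defined just below; note the backend now carries a non-NULL device handle obtained from the registry:

```cpp
#include <cstdio>
#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main() {
    // Hypothetical usage: create a backend for Vulkan device 0, verify it, free it.
    ggml_backend_t backend = ggml_backend_vk_init(0);
    if (backend != NULL && ggml_backend_is_vk(backend)) {
        printf("backend: %s\n", ggml_backend_name(backend));
        ggml_backend_free(backend);
    }
    return 0;
}
```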

bool ggml_backend_is_vk(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
}

int ggml_backend_vk_get_device_count() {
return ggml_vk_get_device_count();
}

void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
int dev_idx = vk_instance.device_indices[device];
ggml_vk_get_device_description(dev_idx, description, description_size);
}

void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
GGML_ASSERT(device < (int) vk_instance.device_indices.size());

vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];

vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();

for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
*total = heap.size;
*free = heap.size;
break;
}
}
}
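
Note that this query reports the first device-local heap's size as both free and total: core Vulkan has no free-memory query, and a real free estimate would need an extension such as VK_EXT_memory_budget (an assumption about a possible extension path, not something this PR does). A small usage sketch:

```cpp
#include <cstdio>
#include "ggml-vulkan.h"

int main() {
    size_t free_mem = 0;
    size_t total_mem = 0;
    // Both values come from the first device-local heap; see the loop above.
    ggml_backend_vk_get_device_memory(0, &free_mem, &total_mem);
    printf("device 0: %zu free / %zu total bytes\n", free_mem, total_mem);
    return 0;
}
```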

//////////////////////////

struct ggml_backend_vk_device_context {
int device;
std::string name;
std::string description;
};

static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
return ctx->name.c_str();
}

static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
return ctx->description.c_str();
}

static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
ggml_backend_vk_get_device_memory(ctx->device, free, total);
}

static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
return ggml_backend_vk_buffer_type(ctx->device);
}

static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
UNUSED(dev);
return ggml_backend_vk_host_buffer_type();
}

static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
UNUSED(dev);
return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
}

static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_vk_device_get_name(dev);
props->description = ggml_backend_vk_device_get_description(dev);
props->type = ggml_backend_vk_device_get_type(dev);
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* async */ false,
/* host_buffer */ true,
/* events */ false,
};
}
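
These per-device callbacks back the generic ggml_backend_dev_get_props() entry point. A short sketch of reading the properties from the public side, assuming the device API in ggml-backend.h:

```cpp
#include <cstdio>
#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main() {
    ggml_backend_reg_t reg = ggml_backend_vk_reg();
    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        struct ggml_backend_dev_props props;
        ggml_backend_dev_get_props(ggml_backend_reg_dev_get(reg, i), &props);
        // memory_free equals memory_total here; see the memory query note above.
        printf("%s (%s): %zu bytes\n", props.name, props.description, props.memory_total);
    }
    return 0;
}
```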

static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
UNUSED(params);
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
return ggml_backend_vk_init(ctx->device);
}

static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
@@ -6701,97 +6827,101 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
return false;
}

-UNUSED(backend);
-}
-
-static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
-const int min_batch_size = 32;
-
-return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
-UNUSED(backend);
+UNUSED(dev);
}

-static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
return false;
}

+ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
-ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

-return buft_ctx->device == ctx->device;
-}
-
-// TODO: enable async and synchronize
-static ggml_backend_i ggml_backend_vk_interface = {
-/* .get_name = */ ggml_backend_vk_name,
-/* .free = */ ggml_backend_vk_free,
-/* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
-/* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
-/* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
-/* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
-/* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
-/* .graph_plan_create = */ NULL,
-/* .graph_plan_free = */ NULL,
-/* .graph_plan_update = */ NULL,
-/* .graph_plan_compute = */ NULL,
-/* .graph_compute = */ ggml_backend_vk_graph_compute,
-/* .supports_op = */ ggml_backend_vk_supports_op,
-/* .supports_buft = */ ggml_backend_vk_supports_buft,
-/* .offload_op = */ ggml_backend_vk_offload_op,
-/* .event_record = */ NULL,
-/* .event_wait = */ NULL,
-};
-
-static ggml_guid_t ggml_backend_vk_guid() {
-static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
-return &guid;
+return buft_ctx->device->idx == ctx->device;
}

-ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
-VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
+static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+const int min_batch_size = 32;

-ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
-ggml_vk_init(ctx, dev_num);
+return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);

-ggml_backend_t vk_backend = new ggml_backend {
-/* .guid = */ ggml_backend_vk_guid(),
-/* .interface = */ ggml_backend_vk_interface,
-/* .device = */ nullptr,
-/* .context = */ ctx,
-};
+UNUSED(dev);
}
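
The offload heuristic above is purely shape-based: an op is worth moving to the GPU once its batch dimension reaches 32, except GGML_OP_GET_ROWS (never offloaded), while GGML_OP_MUL_MAT_ID batches along ne[2] instead. A standalone restatement for illustration, with hypothetical inputs (names here are not ggml API):

```cpp
#include <cstdio>

// Re-statement of the heuristic with the tensor reduced to the fields it
// actually inspects; worth_offloading is an illustrative name only.
static bool worth_offloading(long long ne1, long long ne2, bool is_get_rows, bool is_mul_mat_id) {
    const int min_batch_size = 32;
    return (ne1 >= min_batch_size && !is_get_rows) ||
           (ne2 >= min_batch_size && is_mul_mat_id);
}

int main() {
    printf("%d\n", worth_offloading(512, 1, false, false)); // 1: batch of 512 rows
    printf("%d\n", worth_offloading(512, 1, true, false));  // 0: GET_ROWS stays on CPU
    printf("%d\n", worth_offloading(1, 64, false, true));   // 1: MUL_MAT_ID batches on ne[2]
    return 0;
}
```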

+static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
+/* .get_name = */ ggml_backend_vk_device_get_name,
+/* .get_description = */ ggml_backend_vk_device_get_description,
+/* .get_memory = */ ggml_backend_vk_device_get_memory,
+/* .get_type = */ ggml_backend_vk_device_get_type,
+/* .get_props = */ ggml_backend_vk_device_get_props,
+/* .init_backend = */ ggml_backend_vk_device_init,
+/* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
+/* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
+/* .buffer_from_host_ptr = */ NULL,
+/* .supports_op = */ ggml_backend_vk_device_supports_op,
+/* .supports_buft = */ ggml_backend_vk_device_supports_buft,
+/* .offload_op = */ ggml_backend_vk_device_offload_op,
+/* .event_new = */ NULL,
+/* .event_free = */ NULL,
+/* .event_synchronize = */ NULL,
+};

-return vk_backend;
+static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
+UNUSED(reg);
+return GGML_VK_NAME;
}

-bool ggml_backend_is_vk(ggml_backend_t backend) {
-return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
+static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
+UNUSED(reg);
+return ggml_backend_vk_get_device_count();
}

-int ggml_backend_vk_get_device_count() {
-return ggml_vk_get_device_count();
-}
+static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+static std::vector<ggml_backend_dev_t> devices;

-void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
-ggml_vk_get_device_description(device, description, description_size);
-}
+static bool initialized = false;

-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
-GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+{
+static std::mutex mutex;
+std::lock_guard<std::mutex> lock(mutex);
+if (!initialized) {
+for (size_t i = 0; i < ggml_backend_vk_get_device_count(); i++) {
+ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
+char desc[256];
+ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
+ctx->device = i;
+ctx->name = GGML_VK_NAME + std::to_string(i);
+ctx->description = desc;
+devices.push_back(new ggml_backend_device {
+/* .iface = */ ggml_backend_vk_device_i,
+/* .reg = */ reg,
+/* .context = */ ctx,
+});
+}
+initialized = true;
+}
+}

-vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+GGML_ASSERT(device < devices.size());
+return devices[device];
+}

-vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
+/* .get_name = */ ggml_backend_vk_reg_get_name,
+/* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
+/* .get_device = */ ggml_backend_vk_reg_get_device,
+/* .get_proc_address = */ NULL,
+};

-for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
-if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
-*total = heap.size;
-*free = heap.size;
-break;
-}
-}
-}
+ggml_backend_reg_t ggml_backend_vk_reg() {
+static ggml_backend_reg reg = {
+/* .iface = */ ggml_backend_vk_reg_i,
+/* .context = */ nullptr,
+};
+
+return &reg;
}

// Extension availability
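
The registry object itself is a static singleton, and its device list is built once, lazily, behind a mutex. For comparison, the same build-once pattern can be written with std::call_once; the following is only a sketch of the pattern with simplified types, not the PR's code:

```cpp
#include <cstdio>
#include <mutex>
#include <vector>

// Same lazy build-once pattern as ggml_backend_vk_reg_get_device, with an int
// standing in for ggml_backend_dev_t and a fixed hypothetical device count.
static int example_get_device(size_t device) {
    static std::vector<int> devices;
    static std::once_flag once;
    std::call_once(once, [] {
        for (int i = 0; i < 4; i++) { // hypothetical device count
            devices.push_back(i);
        }
    });
    return devices.at(device);
}

int main() {
    printf("device 2 -> handle %d\n", example_get_device(2));
    return 0;
}
```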